diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 13fe1aaaa..f57385813 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -12,16 +12,46 @@ Changelog 0.26.0 (unreleased) ------------------- + +.. warning:: + + This release contains **breaking changes** as part of the context-first architecture migration. + Please read the :ref:`migration_guide` before upgrading. + +Breaking Changes +^^^^^^^^^^^^^^^^ +- **Context-first architecture**: All ORM state now lives in ``TortoiseContext`` instances +- **Removed** legacy test classes: ``test.TestCase``, ``test.IsolatedTestCase``, ``test.TruncationTestCase``, ``test.SimpleTestCase`` +- **Removed** legacy test helpers: ``initializer()``, ``finalizer()``, ``env_initializer()``, ``getDBConfig()`` +- **Changed** ``Tortoise.init()`` now returns ``TortoiseContext`` (previously returned ``None``) +- **Changed** Multiple separate ``asyncio.run()`` calls in sequence require explicit context management due to ContextVar scoping (uncommon pattern, see migration guide). The typical single ``asyncio.run(main())`` pattern continues to work unchanged. + +Added +^^^^^ +- ``TortoiseContext`` - explicit context manager for ORM state +- ``tortoise_test_context()`` - modern pytest fixture helper for test isolation +- ``get_connection(alias)`` - function to get connection by alias from current context +- ``get_connections()`` - function to get the ConnectionHandler from current context +- ``Tortoise.close_connections()`` - class method to close all connections +- ``Tortoise.is_inited()`` - explicit method version of ``Tortoise._inited`` property + +Changed +^^^^^^^ +- Framework integrations (FastAPI, Starlette, Sanic, etc.) 
now use ``Tortoise.close_connections()`` internally +- ``ConnectionHandler`` now uses instance-based ContextVar storage (each context has isolated connections) +- ``Tortoise.apps`` and ``Tortoise._inited`` now use ``classproperty`` descriptor (no metaclass) +- feat: foreignkey to model type (#2027) + +Deprecated +^^^^^^^^^^ +- ``from tortoise import connections`` - use ``get_connection()`` / ``get_connections()`` functions instead (still works but deprecated) + Fixed ^^^^^ - Fix ``AttributeError`` when using ``tortoise-orm`` with Nuitka-compiled Python code (#2053) - Fix 'Self' in python standard library typing.py, but tortoise/model.py required it in 'typing_extensions' (#2051) - Fix annotations being selected in ValuesListQuery despite not specified in `.values_list` fields list (#2059) -Changed -^^^^^ -- feat: foreignkey to model type (#2027) - 0.25 ==== diff --git a/conftest.py b/conftest.py index 4f0182764..7d0dfae30 100644 --- a/conftest.py +++ b/conftest.py @@ -1,13 +1,20 @@ +""" +Pytest configuration for Tortoise ORM tests. + +Uses function-scoped fixtures for true test isolation. 
+""" + import os import pytest +import pytest_asyncio -from tortoise.contrib.test import finalizer, initializer +from tortoise.context import tortoise_test_context @pytest.fixture(scope="session", autouse=True) -def initialize_tests(request): - # Reduce the default timeout for psycopg because the tests become very slow otherwise +def configure_psycopg(): + """Configure psycopg timeout for faster tests.""" try: from tortoise.backends.psycopg import PsycopgClient @@ -15,6 +22,195 @@ def initialize_tests(request): except ImportError: pass + +# ============================================================================ +# HELPER FUNCTIONS +# ============================================================================ + + +async def _truncate_all_tables(ctx) -> None: + """Truncate all tables in the given context.""" + if ctx.apps: + for model in ctx.apps.get_models_iterable(): + quote_char = model._meta.db.query_class.SQL_CONTEXT.quote_char + await model._meta.db.execute_script( + f"DELETE FROM {quote_char}{model._meta.db_table}{quote_char}" # nosec + ) + + +# ============================================================================ +# PYTEST FIXTURES FOR TESTS +# These fixtures provide different isolation patterns for test scenarios +# ============================================================================ + + +@pytest_asyncio.fixture(scope="module") +async def db_module(): + """ + Module-scoped fixture: Creates TortoiseContext once per test module. + + This is the base fixture that creates the database schema once per module. + Other fixtures build on top of this for different isolation strategies. + + Note: Uses connection_label="models" to match standard test infrastructure. 
+ """ db_url = os.getenv("TORTOISE_TEST_DB", "sqlite://:memory:") - initializer(["tests.testmodels"], db_url=db_url) - request.addfinalizer(finalizer) + async with tortoise_test_context( + modules=["tests.testmodels"], + db_url=db_url, + app_label="models", + connection_label="models", + ) as ctx: + yield ctx + + +@pytest_asyncio.fixture(scope="function") +async def db(db_module): + """ + Function-scoped fixture with transaction rollback cleanup. + + Each test runs inside a transaction that gets rolled back at the end, + providing isolation without the overhead of schema recreation. + + For databases that don't support transactions (e.g., MySQL MyISAM), + falls back to truncation cleanup. + + This is the FASTEST isolation method - use for most tests. + + Usage: + @pytest.mark.asyncio + async def test_something(db): + obj = await Model.create(name="test") + assert obj.id is not None + # Changes are rolled back after test + """ + # Get connection from the context using its default connection + conn = db_module.db() + + # Check if the database supports transactions + if conn.capabilities.supports_transactions: + # Start a savepoint/transaction + transaction = conn._in_transaction() + await transaction.__aenter__() + + try: + yield db_module + finally: + # Rollback the transaction (discards all changes made during test) + class _RollbackException(Exception): + pass + + await transaction.__aexit__(_RollbackException, _RollbackException(), None) + else: + # For databases without transaction support (e.g., MyISAM), + # fall back to truncation cleanup + yield db_module + await _truncate_all_tables(db_module) + + +@pytest_asyncio.fixture(scope="function") +async def db_simple(db_module): + """ + Function-scoped fixture with NO cleanup between tests. + + Tests share state - data from one test persists to the next within the module. + Use ONLY for read-only tests or tests that manage their own cleanup. 
+ + Usage: + @pytest.mark.asyncio + async def test_read_only(db_simple): + # Read-only operations, no writes + config = get_config() + assert "host" in config + """ + yield db_module + + +@pytest_asyncio.fixture(scope="function") +async def db_isolated(): + """ + Function-scoped fixture with full database recreation per test. + + Creates a completely fresh database for EACH test. This is the SLOWEST + method but provides maximum isolation. + + Use when: + - Testing database creation/dropping + - Tests need custom model modules + - Tests must have completely clean state + + Usage: + @pytest.mark.asyncio + async def test_with_fresh_db(db_isolated): + # Completely fresh database + ... + """ + db_url = os.getenv("TORTOISE_TEST_DB", "sqlite://:memory:") + async with tortoise_test_context( + modules=["tests.testmodels"], + db_url=db_url, + app_label="models", + connection_label="models", + ) as ctx: + yield ctx + + +@pytest_asyncio.fixture(scope="function") +async def db_truncate(db_module): + """ + Function-scoped fixture with table truncation cleanup. + + After each test, all tables are truncated (DELETE FROM). + Faster than db_isolated but slower than db (transaction rollback). + + Use when testing transaction behavior (can't use rollback for cleanup). + + Usage: + @pytest.mark.asyncio + async def test_with_transactions(db_truncate): + async with in_transaction(): + await Model.create(name="test") + # Table truncated after test + """ + yield db_module + await _truncate_all_tables(db_module) + + +# ============================================================================ +# HELPER FIXTURES +# ============================================================================ + + +def make_db_fixture( + modules: list[str], app_label: str = "models", connection_label: str = "models" +): + """ + Factory function to create custom db fixtures with different modules. + + Use this in subdirectory conftest.py files for tests that need + custom model modules. 
+ + Example usage in tests/fields/conftest.py: + db_array = make_db_fixture(["tests.fields.test_array"]) + + Args: + modules: List of module paths to discover models from. + app_label: The app label for the models, defaults to "models". + connection_label: The connection alias name, defaults to "models". + + Returns: + An async fixture function. + """ + + @pytest_asyncio.fixture(scope="function") + async def _db_fixture(): + db_url = os.getenv("TORTOISE_TEST_DB", "sqlite://:memory:") + async with tortoise_test_context( + modules=modules, + db_url=db_url, + app_label=app_label, + connection_label=connection_label, + ) as ctx: + yield ctx + + return _db_fixture diff --git a/docs/conf.py b/docs/conf.py index 5062de767..dc3829e78 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -7,7 +7,6 @@ import json # -- Path setup -------------------------------------------------------------- - # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. diff --git a/docs/connections.rst b/docs/connections.rst index 344c89fec..31813cf8d 100644 --- a/docs/connections.rst +++ b/docs/connections.rst @@ -4,21 +4,72 @@ Connections =========== -This document describes how to access the underlying connection object (:ref:`BaseDBAsyncClient`) for the aliases defined -as part of the DB config passed to the :meth:`Tortoise.init` call. +This document describes how to access database connections in Tortoise ORM. -Below is a simple code snippet which shows how the interface can be accessed: +.. contents:: + :local: + :depth: 2 -.. code-block:: python3 +Accessing Connections +===================== - # connections is a singleton instance of the ConnectionHandler class and serves as the - # entrypoint to access all connection management APIs. 
- from tortoise import connections +Tortoise ORM provides multiple ways to access database connections: +Via ``Tortoise`` Class (Recommended) +------------------------------------ + +The simplest way to access connections: + +.. code-block:: python + + from tortoise import Tortoise + + # Get a specific connection by alias + conn = Tortoise.get_connection("default") + + # Execute raw queries + result = await conn.execute_query('SELECT * FROM "user"') + +Via Helper Functions +-------------------- + +For more direct access to the connection handler: + +.. code-block:: python + + from tortoise.connection import get_connection, get_connections + + # Get a specific connection + conn = get_connection("default") + + # Get the connection handler (access to all connections) + handler = get_connections() + all_connections = handler.all() + +Via Context (Advanced) +---------------------- + +When working with explicit contexts: + +.. code-block:: python + + from tortoise.context import TortoiseContext + + async with TortoiseContext() as ctx: + await ctx.init(db_url="sqlite://:memory:", modules={"models": ["myapp.models"]}) + + # Access connections via context + conn = ctx.connections.get("default") + +Connection Configuration +======================== + +Connections are configured when calling ``Tortoise.init()``: + +.. code-block:: python - # Assume that this is the Tortoise configuration used await Tortoise.init( - { + config={ "connections": { "default": { "engine": "tortoise.backends.sqlite", @@ -26,33 +77,135 @@ Below is a simple code snippet which shows how the interface can be accessed: } }, "apps": { - "events": {"models": ["__main__"], "default_connection": "default"} + "models": {"models": ["__main__"], "default_connection": "default"} }, } ) - conn: BaseDBAsyncClient = connections.get("default") - try: - await conn.execute_query('SELECT * FROM "event"') - except OperationalError: - print("Expected it to fail") +Or using a DB URL: + +.. code-block:: python -.. 
important:: - The :ref:`tortoise.connection.ConnectionHandler` class has been implemented with the singleton - pattern in mind and so when the ORM initializes, a singleton instance of this class - ``tortoise.connection.connections`` is created automatically and lives in memory up until the lifetime of the app. - Any attempt to modify or override its behaviour at runtime is risky and not recommended. + await Tortoise.init( + db_url="sqlite://example.sqlite3", + modules={"models": ["__main__"]} + ) +Multiple Databases +================== -Please refer to :ref:`this example` for a detailed demonstration of how this API can be used -in practice. +Configure multiple connections for different databases: +.. code-block:: python + + await Tortoise.init( + config={ + "connections": { + "default": "sqlite://primary.sqlite3", + "secondary": "postgres://user:pass@localhost:5432/secondary", + }, + "apps": { + "primary_models": { + "models": ["myapp.primary_models"], + "default_connection": "default", + }, + "secondary_models": { + "models": ["myapp.secondary_models"], + "default_connection": "secondary", + } + }, + } + ) + + # Access specific connections + primary_conn = Tortoise.get_connection("default") + secondary_conn = Tortoise.get_connection("secondary") + +Please refer to :ref:`this example` for a detailed demonstration. + +Closing Connections +=================== + +Always close connections when shutting down your application: + +.. code-block:: python + + # Close all connections + await Tortoise.close_connections() + + # Or via helper function + from tortoise.connection import get_connections + await get_connections().close_all() + +In framework integrations, this is typically handled automatically on shutdown. + +Connection Lifecycle +==================== + +Connections are created lazily when first accessed and are managed by the +``ConnectionHandler`` class. 
Each ``TortoiseContext`` has its own ``ConnectionHandler``, +providing isolation between different contexts (useful for testing). + +.. code-block:: python + + # Connection is created on first access + conn = Tortoise.get_connection("default") + + # Same connection is returned on subsequent calls + conn2 = Tortoise.get_connection("default") + assert conn is conn2 + + # Closing discards the connection + await Tortoise.close_connections() + + # Next access creates a new connection + conn3 = Tortoise.get_connection("default") + assert conn is not conn3 API Reference -=========== +============= .. _connection_handler: -.. automodule:: tortoise.connection +Helper Functions +---------------- + +.. autofunction:: tortoise.connection.get_connection + +.. autofunction:: tortoise.connection.get_connections + +ConnectionHandler Class +----------------------- + +.. autoclass:: tortoise.connection.ConnectionHandler :members: - :undoc-members: \ No newline at end of file + :undoc-members: + +Migration from Legacy API +========================= + +If you're upgrading from an older version that used the ``connections`` singleton, +see the :ref:`migration_guide` for details. + +.. note:: + + The ``connections`` singleton still works but is deprecated. It now acts as a + proxy that delegates to the current context's ``ConnectionHandler``. New code + should use ``get_connection()`` / ``get_connections()`` or ``Tortoise.get_connection()``. + +**Quick reference:** + +.. 
list-table:: API Migration + :header-rows: 1 + :widths: 50 50 + + * - Old Pattern (Deprecated) + - New Pattern + * - ``from tortoise import connections`` + - ``from tortoise.connection import get_connections`` + * - ``connections.get("alias")`` + - ``Tortoise.get_connection("alias")`` + * - ``connections.all()`` + - ``get_connections().all()`` + * - ``connections.close_all()`` + - ``Tortoise.close_connections()`` diff --git a/docs/contrib/unittest.rst b/docs/contrib/unittest.rst index 1156b98fa..7540e2eeb 100644 --- a/docs/contrib/unittest.rst +++ b/docs/contrib/unittest.rst @@ -1,133 +1,233 @@ .. _unittest: -================ -UnitTest support -================ +============== +Testing Support +============== -Tortoise ORM includes its own helper utilities to assist in unit tests. +Tortoise ORM provides testing utilities designed for pytest with true test isolation. +Each test gets its own database context, ensuring tests don't interfere with each other. -Usage -===== +.. contents:: + :local: + :depth: 2 -.. code-block:: python3 +Quick Start +=========== - from tortoise.contrib import test +1. Create a ``conftest.py`` file in your tests directory: - class TestSomething(test.TestCase): - def test_something(self): - ... +.. code-block:: python - async def test_something_async(self): - ... + import os + import pytest_asyncio + from tortoise.contrib.test import tortoise_test_context - @test.skip('Skip this') - def test_skip(self): - ... + @pytest_asyncio.fixture + async def db(): + """Provide isolated database context for each test.""" + db_url = os.getenv("TORTOISE_TEST_DB", "sqlite://:memory:") + async with tortoise_test_context(["myapp.models"], db_url=db_url) as ctx: + yield ctx - @test.expectedFailure - def test_something(self): - ... +2. Write your tests as async functions: +.. 
code-block:: python -To get ``test.TestCase`` to work as expected, you need to configure your test environment setup and teardown to call the following: + import pytest + from myapp.models import User -.. code-block:: python3 + @pytest.mark.asyncio + async def test_create_user(db): + user = await User.create(name="Test User", email="test@example.com") + assert user.id is not None + assert user.name == "Test User" + + @pytest.mark.asyncio + async def test_filter_users(db): + await User.create(name="Alice") + await User.create(name="Bob") + + users = await User.filter(name="Alice") + assert len(users) == 1 + assert users[0].name == "Alice" - from tortoise.contrib.test import initializer, finalizer +3. Run your tests: - # In setup - initializer(['module.a', 'module.b.c']) - # With optional db_url, app_label and loop parameters - initializer(['module.a', 'module.b.c'], db_url='...', app_label="someapp", loop=loop) - # Or env-var driven → See Green test runner section below. - env_initializer() +.. code-block:: bash - # In teardown - finalizer() + pytest tests/ -v +``tortoise_test_context`` Reference +=================================== -On the DB_URL it should follow the following standard: +The ``tortoise_test_context`` function creates an isolated ORM context for testing: - TORTOISE_TEST_DB=sqlite:///tmp/test-{}.sqlite - TORTOISE_TEST_DB=postgres://postgres:@127.0.0.1:5432/test_{} +.. code-block:: python + from tortoise.contrib.test import tortoise_test_context -The ``{}`` is a string-replacement parameter, that will create a randomized database name. -This is currently required for ``test.IsolatedTestCase`` to function. -If you don't use ``test.IsolatedTestCase`` then you can give an absolute address. -The SQLite in-memory ``:memory:`` database will always work, and is the default. 
+ async with tortoise_test_context( + modules=["myapp.models"], # Required: List of model modules + db_url="sqlite://:memory:", # Optional: Database URL (default: sqlite://:memory:) + app_label="models", # Optional: App label (default: "models") + connection_label="default", # Optional: Connection alias (default: "default") + ) as ctx: + # Your test code here + pass -.. rst-class:: emphasize-children +**Parameters:** -Test Runners -============ +- ``modules`` (list): List of module paths containing your models. Required. +- ``db_url`` (str): Database connection URL. Defaults to ``sqlite://:memory:``. +- ``app_label`` (str): Label for the app in the ORM registry. Defaults to ``"models"``. +- ``connection_label`` (str): Alias for the database connection. Defaults to ``"default"``. -Green ------ +The context manager: -In your ``.green`` file: +1. Creates a fresh ``TortoiseContext`` +2. Initializes the ORM with the given configuration +3. Generates database schemas +4. Yields the context for your test +5. Closes all connections on exit -.. code-block:: ini +Testing with Multiple Databases +=============================== - initializer = tortoise.contrib.test.env_initializer - finalizer = tortoise.contrib.test.finalizer +For tests that require multiple database connections: -And then define the ``TORTOISE_TEST_MODULES`` environment variable with a comma separated list of module paths. +.. 
code-block:: python -Furthermore, you may set the database configuration parameter as an environment variable (defaults to ``sqlite://:memory:``): + import pytest_asyncio + from tortoise.context import TortoiseContext - TORTOISE_TEST_DB=sqlite:///tmp/test-{}.sqlite - TORTOISE_TEST_DB=postgres://postgres:@127.0.0.1:5432/test_{} + @pytest_asyncio.fixture + async def multi_db(): + """Fixture for testing with multiple databases.""" + async with TortoiseContext() as ctx: + await ctx.init(config={ + "connections": { + "primary": "sqlite://:memory:", + "secondary": "sqlite://:memory:", + }, + "apps": { + "models": { + "models": ["myapp.models"], + "default_connection": "primary", + }, + "archive": { + "models": ["myapp.archive_models"], + "default_connection": "secondary", + } + } + }) + await ctx.generate_schemas() + yield ctx +Testing Database Capabilities +============================= -Py.test -------- +Use ``requireCapability`` to skip tests based on database capabilities: -.. note:: +.. code-block:: python - pytest 5.4.0 & 5.4.1 has a bug that stops it from working with async test cases. You may have to install ``pytest>=5.4.2`` to get it to work. + from tortoise.contrib.test import requireCapability -Run the initializer and finalizer in your ``conftest.py`` file: + @pytest.mark.asyncio + @requireCapability(dialect="postgres") + async def test_postgres_specific_feature(db): + """This test only runs on PostgreSQL.""" + # Test postgres-specific functionality + pass -.. code-block:: python3 + @pytest.mark.asyncio + @requireCapability(dialect="sqlite") + async def test_sqlite_specific_feature(db): + """This test only runs on SQLite.""" + pass - import os - import pytest - from tortoise.contrib.test import finalizer, initializer +Environment Variables +===================== + +Configure your test database via environment variables: + +.. 
code-block:: bash + + # SQLite (default) + export TORTOISE_TEST_DB="sqlite://:memory:" + + # PostgreSQL + export TORTOISE_TEST_DB="postgres://user:pass@localhost:5432/testdb" - @pytest.fixture(scope="session", autouse=True) - def initialize_tests(request): - db_url = os.environ.get("TORTOISE_TEST_DB", "sqlite://:memory:") - initializer(["tests.testmodels"], db_url=db_url, app_label="models") - request.addfinalizer(finalizer) + # MySQL + export TORTOISE_TEST_DB="mysql://user:pass@localhost:3306/testdb" +Using ``{}`` in the URL creates randomized database names (useful for parallel testing): -Nose2 ------ +.. code-block:: bash + + export TORTOISE_TEST_DB="sqlite:///tmp/test-{}.sqlite" + export TORTOISE_TEST_DB="postgres://user:pass@localhost:5432/test_{}" + +Utility Functions +================= -Load the plugin ``tortoise.contrib.test.nose2`` either via command line:: +truncate_all_models +------------------- - nose2 --plugin tortoise.contrib.test.nose2 --db-module tortoise.tests.testmodels +Truncate all model tables in the current context: -Or via the config file: +.. code-block:: python + + from tortoise.contrib.test import truncate_all_models -.. code-block:: ini + @pytest.mark.asyncio + async def test_with_truncation(db): + # Create some data + await User.create(name="Test") + + # Truncate all tables + await truncate_all_models() - [unittest] - plugins = tortoise.contrib.test.nose2 + # Tables are now empty + count = await User.all().count() + assert count == 0 - [tortoise] - # Must specify at least one module path - db-module = - tests.testmodels - # You can optionally override the db_url here - db-url = sqlite://testdb-{}.sqlite +Migration from Legacy Test Classes +================================== +If you're upgrading from the legacy ``test.TestCase`` classes, see the +:ref:`migration_guide` for detailed migration instructions. + +**Quick reference:** + +.. 
list-table:: Migration Mapping + :header-rows: 1 + :widths: 40 60 + + * - Legacy (Removed) + - Modern Replacement + * - ``test.TestCase`` + - pytest + ``db`` fixture + * - ``test.IsolatedTestCase`` + - pytest + ``db`` fixture (isolation is default) + * - ``test.TruncationTestCase`` + - pytest + ``db`` fixture + ``truncate_all_models()`` + * - ``test.SimpleTestCase`` + - pytest + ``db`` fixture + * - ``initializer()`` + - ``tortoise_test_context()`` + * - ``finalizer()`` + - (automatic with context manager) + * - ``self.assertEqual(a, b)`` + - ``assert a == b`` + * - ``self.assertIn(a, b)`` + - ``assert a in b`` + * - ``self.assertRaises(Exc)`` + - ``pytest.raises(Exc)`` Reference ========= .. automodule:: tortoise.contrib.test - :members: - :undoc-members: + :members: tortoise_test_context, truncate_all_models, requireCapability + :show-inheritance: diff --git a/docs/migration_guide.rst b/docs/migration_guide.rst new file mode 100644 index 000000000..570dfe112 --- /dev/null +++ b/docs/migration_guide.rst @@ -0,0 +1,462 @@ +.. _migration_guide: + +==================================== +Migration Guide: Tortoise 1.0 +==================================== + +This guide covers the breaking changes and migration steps for upgrading to Tortoise ORM 1.0+ +which introduces an isolated-context architecture for improved test isolation and cleaner state management. + +.. contents:: + :local: + :depth: 2 + +Overview +======== + +Tortoise ORM 1.0 introduces an **isolated-context architecture** that: + +- Removes global state (``_default_context``, metaclass) +- Uses ``TortoiseContext`` as the single source of truth +- Provides test isolation with ``tortoise_test_context()`` +- Simplifies connection management + +Most application code continues to work unchanged. The main changes affect: + +1. Direct access to the ``connections`` singleton +2. Test infrastructure (``test.TestCase``, ``initializer``, etc.) +3.
Multiple ``asyncio.run()`` call patterns + +Quick Reference +=============== + +.. list-table:: API Changes + :header-rows: 1 + :widths: 40 60 + + * - Old Pattern + - New Pattern + * - ``from tortoise import connections`` (deprecated) + - ``from tortoise.connection import get_connection, get_connections`` + * - ``connections.get("default")`` (still works) + - ``Tortoise.get_connection("default")`` or ``get_connection("default")`` + * - ``connections.close_all()`` (still works) + - ``Tortoise.close_connections()`` + * - ``test.TestCase`` (removed) + - pytest + ``db`` fixture + * - ``initializer()`` / ``finalizer()`` (removed) + - ``tortoise_test_context()`` + +What Stays the Same +=================== + +The following APIs continue to work unchanged: + +.. code-block:: python + + # Initialization (unchanged) + await Tortoise.init(config=...) + await Tortoise.init(db_url="...", modules={...}) + await Tortoise.generate_schemas() + + # Accessing apps (unchanged) + Tortoise.apps + Tortoise._inited + + # Model operations (unchanged) + await User.create(name="test") + await User.filter(name="test").first() + + # Framework integrations (unchanged for users) + # FastAPI, Starlette, Sanic, etc. + +Connection Access Changes +========================= + +Old Pattern (Deprecated) +------------------------ + +.. code-block:: python + + from tortoise import connections + + conn = connections.get("default") + await connections.close_all() + +New Pattern +----------- + +.. code-block:: python + + from tortoise import Tortoise + # Or: from tortoise.connection import get_connection, get_connections + + # Get a single connection + conn = Tortoise.get_connection("default") + + # Get the connection handler + handler = get_connections() + all_conns = handler.all() + + # Close all connections + await Tortoise.close_connections() + +Test Migration +============== + +The legacy test base classes (``TestCase``, ``IsolatedTestCase``, etc.) 
and helper +functions (``initializer``, ``finalizer``) have been replaced with a pytest-based +approach using ``tortoise_test_context()``. + +Old Test Pattern +---------------- + +.. code-block:: python + + from tortoise.contrib import test + + class TestUser(test.TestCase): + async def test_create(self): + user = await User.create(name="Test") + self.assertEqual(user.name, "Test") + + async def test_filter(self): + await User.create(name="Test") + users = await User.filter(name="Test") + self.assertEqual(len(users), 1) + +With ``conftest.py``: + +.. code-block:: python + + from tortoise.contrib.test import initializer, finalizer + + @pytest.fixture(scope="session", autouse=True) + def initialize_tests(request): + initializer(["myapp.models"]) + request.addfinalizer(finalizer) + +New Test Pattern +---------------- + +.. code-block:: python + + import pytest + from tests.testmodels import User + + @pytest.mark.asyncio + async def test_create(db): + user = await User.create(name="Test") + assert user.name == "Test" + + @pytest.mark.asyncio + async def test_filter(db): + await User.create(name="Test") + users = await User.filter(name="Test") + assert len(users) == 1 + +With ``conftest.py``: + +.. code-block:: python + + import pytest_asyncio + from tortoise.contrib.test import tortoise_test_context + + @pytest_asyncio.fixture + async def db(): + async with tortoise_test_context(["myapp.models"]) as ctx: + yield ctx + +Migration Checklist +------------------- + +For each test file: + +1. Replace ``from tortoise.contrib import test`` with ``import pytest`` +2. Remove class wrapper (``class TestXxx(test.TestCase):``) +3. Add ``@pytest.mark.asyncio`` decorator to each async test +4. Add ``db`` fixture parameter to each test function +5. 
Replace assertion methods: + - ``self.assertEqual(a, b)`` → ``assert a == b`` + - ``self.assertIn(a, b)`` → ``assert a in b`` + - ``self.assertRaises(Exc)`` → ``pytest.raises(Exc)`` + - ``self.assertTrue(x)`` → ``assert x`` + - ``self.assertFalse(x)`` → ``assert not x`` + +Multiple ``asyncio.run()`` Calls (Uncommon Pattern) +=================================================== + +.. note:: + + This section only applies if you use multiple **separate** ``asyncio.run()`` calls + in sequence. The typical pattern of a single ``asyncio.run(main())`` that contains + all ORM operations continues to work unchanged. + +If you use multiple separate ``asyncio.run()`` calls (sometimes seen in scripts or REPL +sessions), the ContextVar that tracks ORM state is lost between runs due to Python's +ContextVar scoping rules. This pattern now requires explicit context management. + +As a fallback `_enable_global_fallback` on `Tortoise.init(...)` can be used to set created +context as global fallback. + +Old Pattern (No Longer Works) +----------------------------- + +.. code-block:: python + + import asyncio + from tortoise import Tortoise + + # Context is lost after asyncio.run() completes + asyncio.run(Tortoise.init(db_url="sqlite://:memory:", modules={"models": ["__main__"]})) + asyncio.run(User.create(name="test")) # FAILS: No context + +New Patterns +------------ + +**Option 1: Single asyncio.run (Recommended)** + +.. code-block:: python + + import asyncio + from tortoise import Tortoise + + async def main(): + await Tortoise.init(db_url="sqlite://:memory:", modules={"models": ["__main__"]}) + await Tortoise.generate_schemas() + user = await User.create(name="test") + print(f"Created user: {user.id}") + await Tortoise.close_connections() + + asyncio.run(main()) + +**Option 2: Capture and Reuse Context** + +.. 
code-block:: python + + import asyncio + from tortoise import Tortoise + + # Tortoise.init() returns the context + ctx = asyncio.run(Tortoise.init(db_url="sqlite://:memory:", modules={"models": ["__main__"]})) + + # Re-enter context for subsequent runs + with ctx: + asyncio.run(Tortoise.generate_schemas()) + asyncio.run(User.create(name="test")) + +**Option 3: Explicit Context Manager** + +.. code-block:: python + + import asyncio + from tortoise.context import TortoiseContext + + with TortoiseContext() as ctx: + asyncio.run(ctx.init(db_url="sqlite://:memory:", modules={"models": ["__main__"]})) + asyncio.run(ctx.generate_schemas()) + asyncio.run(User.create(name="test")) + +Using ``TortoiseContext`` Directly +================================== + +For advanced use cases (testing, multi-tenant applications), you can use +``TortoiseContext`` directly: + +.. code-block:: python + + from tortoise.context import TortoiseContext + + async def run_isolated(): + async with TortoiseContext() as ctx: + await ctx.init( + db_url="sqlite://:memory:", + modules={"models": ["myapp.models"]} + ) + await ctx.generate_schemas() + + # All ORM operations use this context + user = await User.create(name="test") + + # Context auto-closes on exit + +Benefits of ``TortoiseContext``: + +- **Test isolation**: Each context has independent connections and state +- **Multi-tenancy**: Different contexts can connect to different databases +- **No global state**: Clear ownership of ORM state +- **Automatic cleanup**: Connections close when context exits + +Framework Integration Changes +============================= + +If you use the built-in framework integrations (FastAPI, Starlette, etc.), no changes +are required. The integrations have been updated internally to use ``Tortoise.close_connections()`` +instead of ``connections.close_all()``. 
+ +Multiple FastAPI Apps (Global Fallback) +--------------------------------------- + +When using ``RegisterTortoise`` with FastAPI, a global fallback context is enabled by default. +This allows Tortoise ORM to work correctly with ``asgi-lifespan`` (used in tests) where the +lifespan runs in a separate background task from the requests. + +If you run **multiple FastAPI apps** in the same process (e.g., in tests), you may encounter: + +.. code-block:: text + + ConfigurationError: Global context fallback is already enabled by another Tortoise.init() call. + +**Solution:** Disable global fallback for secondary apps and use explicit context access: + +.. code-block:: python + + # main_app.py - Primary app (uses global fallback) + from tortoise.contrib.fastapi import RegisterTortoise + + @asynccontextmanager + async def lifespan(app: FastAPI): + async with RegisterTortoise( + app, + db_url="sqlite://:memory:", + modules={"models": ["myapp.models"]}, + ): + yield + + app = FastAPI(lifespan=lifespan) + +.. code-block:: python + + # secondary_app.py - Secondary app (explicit context) + from tortoise.contrib.fastapi import RegisterTortoise + + @asynccontextmanager + async def lifespan(app: FastAPI): + async with RegisterTortoise( + app, + db_url="sqlite://:memory:", + modules={"models": ["myapp.models"]}, + _enable_global_fallback=False, # Disable global fallback + ): + yield + + app_secondary = FastAPI(lifespan=lifespan) + +In tests, access the secondary app's context explicitly via ``app.state``: + +.. 
code-block:: python + + @pytest.fixture + async def client_secondary(): + async with LifespanManager(app_secondary) as manager: + # Get context from app.state and enter it + ctx = app_secondary.state._tortoise_context + with ctx: # Make context current via contextvar + async with AsyncClient(app=app_secondary) as c: + yield c + +The ``_enable_global_fallback`` parameter: + +- ``True`` (default): Sets context as global fallback for cross-task access +- ``False``: Context only accessible via ``app.state._tortoise_context`` + +This is also available in ``Tortoise.init()`` (default ``False``) and +``TortoiseContext.init()`` (default ``False``). + +Custom Integration Migration +---------------------------- + +If you've written custom framework integrations: + +.. code-block:: python + + # Old + from tortoise import connections + + async def shutdown(): + await connections.close_all() + + # New + from tortoise import Tortoise + + async def shutdown(): + await Tortoise.close_connections() + +Removed APIs +============ + +The following APIs have been removed: + +- ``test.TestCase``, ``test.IsolatedTestCase``, ``test.TruncationTestCase`` +- ``test.SimpleTestCase`` +- ``test.initializer()``, ``test.finalizer()`` +- ``test.env_initializer()`` +- ``test.getDBConfig()`` + +Deprecated APIs +=============== + +The following APIs still work but are deprecated: + +- ``from tortoise import connections`` - use ``get_connection()`` / ``get_connections()`` instead + +Still Available +=============== + +The following APIs are still available and work as before: + +- ``init_memory_sqlite()`` decorator - for simple scripts +- ``MEMORY_SQLITE`` constant - ``"sqlite://:memory:"`` +- ``requireCapability()`` - for capability-based test skipping +- ``truncate_all_models()`` - for test cleanup + +Troubleshooting +=============== + +"No TortoiseContext is currently active" +---------------------------------------- + +This error occurs when trying to access ORM features without an active context. 
+ +**Solutions:** + +1. Ensure ``Tortoise.init()`` was called before accessing models +2. If using multiple ``asyncio.run()`` calls, use context manager pattern +3. In tests, ensure the ``db`` fixture is being used + +"Global context fallback is already enabled" +-------------------------------------------- + +This error occurs when multiple ``Tortoise.init()`` or ``RegisterTortoise`` calls +try to enable global fallback simultaneously. + +**Solutions:** + +1. For multiple FastAPI apps, set ``_enable_global_fallback=False`` on secondary apps +2. Access secondary app's context explicitly via ``app.state._tortoise_context`` +3. See "Multiple FastAPI Apps (Global Fallback)" section above + +"ConfigurationError: Connections not initialized" +------------------------------------------------- + +This error occurs when trying to access connections before initialization. + +**Solution:** Ensure ``Tortoise.init()`` or ``ctx.init()`` has been called and awaited. + +Test isolation issues +--------------------- + +If tests are interfering with each other: + +1. Ensure using function-scoped ``db`` fixture (not session-scoped) +2. Use ``tortoise_test_context()`` which provides explicit isolation +3. Remove any ``@pytest.fixture(scope="session")`` that calls ``initializer()`` + +Getting Help +============ + +If you encounter issues during migration: + +1. Check the `GitHub Issues `_ +2. Review the `examples directory `_ +3. Ask in the `GitHub Discussions `_ diff --git a/docs/setup.rst b/docs/setup.rst index e28217e0a..330a6d139 100644 --- a/docs/setup.rst +++ b/docs/setup.rst @@ -79,3 +79,44 @@ To ensure connections are properly closed, make sure to call ``Tortoise.close_co await Tortoise.close_connections() The helper function ``tortoise.run_async()`` automatically ensures that connections are closed when your application terminates. + +.. 
_global_fallback: + +Global Context Fallback +======================= + +By default, Tortoise ORM uses Python's ``contextvars`` to track the active context. This works +well when ``Tortoise.init()`` is called from the same task that will execute queries. + +However, in some scenarios, initialization happens in a **different task** than where queries +run. For example: + +- ASGI lifespan handlers that run in a background task +- Framework setup code that spawns a separate initialization task +- Test harnesses that manage app lifecycle in background tasks + +In these cases, the context set in the initialization task is not visible to other tasks, +resulting in ``RuntimeError: No TortoiseContext is currently active``. + +To solve this, use the ``_enable_global_fallback`` parameter: + +.. code-block:: python3 + + await Tortoise.init( + db_url='sqlite://db.sqlite3', + modules={'models': ['app.models']}, + _enable_global_fallback=True, + ) + +When enabled, the context is stored in a global variable in addition to the contextvar, +making it accessible from any task in the process. + +**Important considerations:** + +- Only **one** global fallback context can be active at a time +- Attempting to enable global fallback when one is already set raises ``ConfigurationError`` +- For multiple isolated contexts, use explicit ``TortoiseContext()`` instances instead +- The global fallback is automatically cleared when ``Tortoise.close_connections()`` is called + +This parameter is also available in ``TortoiseContext.init()`` and framework integrations +like ``RegisterTortoise`` (where it defaults to ``True``). 
diff --git a/docs/sphinx_autodoc_typehints.py b/docs/sphinx_autodoc_typehints.py index 336b069fd..f8c0dde56 100644 --- a/docs/sphinx_autodoc_typehints.py +++ b/docs/sphinx_autodoc_typehints.py @@ -2,12 +2,12 @@ import sys import textwrap import typing -from typing import get_type_hints, TypeVar, Any, AnyStr, Generic, Union - -from sphinx.util import logging -from sphinx.util.inspect import signature as Signature, stringify_signature +from typing import Any, AnyStr, Generic, TypeVar, Union, get_type_hints import type_globals +from sphinx.util import logging +from sphinx.util.inspect import signature as Signature +from sphinx.util.inspect import stringify_signature try: from typing_extensions import Protocol @@ -212,7 +212,7 @@ def get_all_type_hints(obj, name): # Introspecting a slot wrapper will raise TypeError, and and some recursive type # definitions will cause a RecursionError (https://github.com/python/typing/issues/574). pass - except NameError as exc: + except NameError: try: rv = get_type_hints(obj, localns=type_globals.__dict__) except Exception as exc: diff --git a/docs/toc.rst b/docs/toc.rst index de1039114..fa35af600 100644 --- a/docs/toc.rst +++ b/docs/toc.rst @@ -10,6 +10,7 @@ Table Of Contents reference examples contrib + migration_guide CHANGELOG roadmap CONTRIBUTING diff --git a/docs/type_globals.py b/docs/type_globals.py index 02c7a22fe..b2d8b9fcf 100644 --- a/docs/type_globals.py +++ b/docs/type_globals.py @@ -1,3 +1,3 @@ from tortoise import * from tortoise.queryset import Q -from tortoise.backends.base.client import TransactionContext, TransactionalDBClient +from tortoise.backends.base.client import TransactionContext diff --git a/examples/fastapi/_tests.py b/examples/fastapi/_tests.py index d46564783..ce76b8072 100644 --- a/examples/fastapi/_tests.py +++ b/examples/fastapi/_tests.py @@ -57,8 +57,12 @@ async def client() -> ClientManagerType: @pytest.fixture(scope="module") async def client_east() -> ClientManagerType: + # app_east uses 
_enable_global_fallback=False, so we need to explicitly + # enter the context from app.state to make it current for tests async with client_manager(app_east) as c: - yield c + ctx = app_east.state._tortoise_context + with ctx: # Enter context to make it current via contextvar + yield c class UserTester: diff --git a/examples/fastapi/main.py b/examples/fastapi/main.py index b4bd35e7f..97a3ca7f3 100644 --- a/examples/fastapi/main.py +++ b/examples/fastapi/main.py @@ -7,7 +7,8 @@ from routers import router as users_router from examples.fastapi.config import register_orm -from tortoise import Tortoise, generate_config +from tortoise import Tortoise +from tortoise.backends.base.config_generator import generate_config from tortoise.contrib.fastapi import RegisterTortoise, tortoise_exception_handlers diff --git a/examples/fastapi/main_custom_timezone.py b/examples/fastapi/main_custom_timezone.py index e75151155..de59b1286 100644 --- a/examples/fastapi/main_custom_timezone.py +++ b/examples/fastapi/main_custom_timezone.py @@ -10,11 +10,14 @@ @asynccontextmanager async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: # app startup + # Disable global fallback since this is the secondary app in tests + # (main app already uses global fallback). Context is stored in app.state. 
async with register_orm( app, use_tz=False, timezone="Asia/Shanghai", add_exception_handlers=True, + _enable_global_fallback=False, ): # db connected yield diff --git a/examples/pytest/conftest.py b/examples/pytest/conftest.py new file mode 100644 index 000000000..8df37688c --- /dev/null +++ b/examples/pytest/conftest.py @@ -0,0 +1,10 @@ +import pytest_asyncio + +from tortoise.contrib.test import tortoise_test_context + + +@pytest_asyncio.fixture(scope="function") +async def db(): + """Function-scoped fixture for isolated tests.""" + async with tortoise_test_context(["examples.pytest.models"]) as ctx: + yield ctx diff --git a/examples/pytest/models.py b/examples/pytest/models.py new file mode 100644 index 000000000..d3d94ead9 --- /dev/null +++ b/examples/pytest/models.py @@ -0,0 +1,10 @@ +from tortoise import fields +from tortoise.models import Model + + +class User(Model): + id = fields.IntField(primary_key=True) + name = fields.CharField(max_length=100) + + class Meta: + table = "user" diff --git a/examples/pytest/test_example.py b/examples/pytest/test_example.py new file mode 100644 index 000000000..d28b566d8 --- /dev/null +++ b/examples/pytest/test_example.py @@ -0,0 +1,23 @@ +import pytest + +from examples.pytest.models import User + + +@pytest.mark.asyncio +async def test_create_user(db): + user = await User.create(name="Alice") + assert user.id is not None + assert user.name == "Alice" + + +@pytest.mark.asyncio +async def test_query_user(db): + await User.create(name="Bob") + users = await User.filter(name="Bob") + assert len(users) == 1 + + +@pytest.mark.asyncio +async def test_isolation(db): + count = await User.all().count() + assert count == 0, "Database should be empty at test start" diff --git a/pyproject.toml b/pyproject.toml index 442849c2a..2f02ef8dc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,11 +1,11 @@ [project] name = "tortoise-orm" description = "Easy async ORM for python, built with relations in mind" -authors = [{name="Andrey 
Bondar", email="andrey@bondar.ru"}, {name="Nickolas Grigoriadis", email="nagrigoriadis@gmail.com"}, {name="long2ice", email="long2ice@gmail.com"}] -license = {text="Apache-2.0"} +authors = [{ name = "Andrey Bondar", email = "andrey@bondar.ru" }, { name = "Nickolas Grigoriadis", email = "nagrigoriadis@gmail.com" }, { name = "long2ice", email = "long2ice@gmail.com" }] +license = { text = "Apache-2.0" } readme = "README.rst" keywords = ["sql", "mysql", "postgres", "psql", "sqlite", "aiosqlite", "asyncpg", "relational", "database", "rdbms", "orm", "object mapper", "async", "asyncio", "aio", "psycopg"] -dynamic = [ "version" ] +dynamic = ["version"] requires-python = ">=3.10" dependencies = [ "pypika-tortoise (>=0.6.3,<1.0.0)", @@ -74,7 +74,7 @@ dev = [ "types-pytz", "types-PyMySQL", ] -contrib=[ +contrib = [ # Sample integration - Quart "quart", # Sample integration - Sanic @@ -101,6 +101,8 @@ test = [ "pytest-cov", "pytest-codspeed", "pytest-asyncio", + "cryptography>=46.0.3", + "aiomysql>=0.3.2", ] docs = [ # Documentation tools @@ -115,7 +117,7 @@ requires = ["pdm-backend"] build-backend = "pdm.backend" [tool.pdm] -version = {source="file", path="tortoise/__init__.py"} +version = { source = "file", path = "tortoise/__init__.py" } [tool.pdm.build] excludes = ["./**/.git", "./**/.*_cache", "examples"] @@ -188,7 +190,9 @@ per-file-ignores = [ docstring_style = "sphinx" [tool.pytest.ini_options] +asyncio_mode = "strict" asyncio_default_fixture_loop_scope = "session" +asyncio_default_test_loop_scope = "session" filterwarnings = [ 'ignore:`pk` is deprecated:DeprecationWarning', 'ignore:`index` is deprecated:DeprecationWarning', @@ -200,12 +204,13 @@ show_missing = true [tool.ruff] line-length = 100 +exclude = ["docs/type_globals.py"] [tool.ruff.lint] ignore = ["E501"] extend-select = [ - "I", # https://docs.astral.sh/ruff/rules/#isort-i - "FA", # https://docs.astral.sh/ruff/rules/#flake8-future-annotations-fa - "UP", # https://docs.astral.sh/ruff/rules/#pyupgrade-up + 
"I", # https://docs.astral.sh/ruff/rules/#isort-i + "FA", # https://docs.astral.sh/ruff/rules/#flake8-future-annotations-fa + "UP", # https://docs.astral.sh/ruff/rules/#pyupgrade-up "RUF100", # https://docs.astral.sh/ruff/rules/#ruff-specific-rules-ruf ] @@ -216,6 +221,7 @@ extra-standard-library = ["tomllib"] exclude_dirs = [ "tests", "examples/*/_tests.py", + "examples/pytest", "conftest.py", "tortoise/migrations/schema_editor/mssql.py", "examples/postgres_full_text_search.py", diff --git a/tests/backends/test_capabilities.py b/tests/backends/test_capabilities.py index 362213bd3..8a2318c4a 100644 --- a/tests/backends/test_capabilities.py +++ b/tests/backends/test_capabilities.py @@ -1,46 +1,76 @@ +import pytest + from tortoise import connections -from tortoise.contrib import test +from tortoise.contrib.test import requireCapability +from tortoise.exceptions import ConfigurationError + + +@pytest.fixture +def db_and_caps(db): + """Get database connection and capabilities.""" + db_conn = connections.get("models") + return db_conn, db_conn.capabilities + + +@pytest.mark.asyncio +async def test_str(db): + """Test capabilities string representation.""" + caps = connections.get("models").capabilities + assert "requires_limit" in str(caps) + + +@pytest.mark.asyncio +async def test_immutability_1(db): + """Test capabilities are immutable.""" + caps = connections.get("models").capabilities + assert isinstance(caps.dialect, str) + with pytest.raises(AttributeError): + caps.dialect = "foo" + +@pytest.mark.xfail(raises=ConfigurationError, reason="Connection 'other' does not exist") +@requireCapability(connection_name="other") +@pytest.mark.asyncio +async def test_connection_name(db): + """Will fail with a ConfigurationError since connection 'other' does not exist.""" + pass -class TestCapabilities(test.TestCase): - # pylint: disable=E1101 - async def asyncSetUp(self) -> None: - await super().asyncSetUp() - self.db = connections.get("models") - self.caps = 
self.db.capabilities +@requireCapability(dialect="sqlite") +@pytest.mark.xfail(reason="Test is expected to fail - testing xfail behavior") +@pytest.mark.asyncio +async def test_actually_runs(db): + """Test that xfail actually runs.""" + assert False - def test_str(self): - self.assertIn("requires_limit", str(self.caps)) - def test_immutability_1(self): - self.assertIsInstance(self.caps.dialect, str) - with self.assertRaises(AttributeError): - self.caps.dialect = "foo" +@pytest.mark.asyncio +async def test_attribute_error(db): + """Test capabilities raise AttributeError on invalid attribute assignment.""" + caps = connections.get("models").capabilities + with pytest.raises(AttributeError): + caps.bar = "foo" - @test.expectedFailure - @test.requireCapability(connection_name="other") - def test_connection_name(self): - # Will fail with a `KeyError` since the connection `"other"` does not exist. - pass - @test.requireCapability(dialect="sqlite") - @test.expectedFailure - def test_actually_runs(self): - self.assertTrue(False) # pylint: disable=W1503 +@requireCapability(dialect="sqlite") +@pytest.mark.asyncio +async def test_dialect_sqlite(db): + """Test sqlite dialect capability.""" + caps = connections.get("models").capabilities + assert caps.dialect == "sqlite" - def test_attribute_error(self): - with self.assertRaises(AttributeError): - self.caps.bar = "foo" - @test.requireCapability(dialect="sqlite") - def test_dialect_sqlite(self): - self.assertEqual(self.caps.dialect, "sqlite") +@requireCapability(dialect="mysql") +@pytest.mark.asyncio +async def test_dialect_mysql(db): + """Test mysql dialect capability.""" + caps = connections.get("models").capabilities + assert caps.dialect == "mysql" - @test.requireCapability(dialect="mysql") - def test_dialect_mysql(self): - self.assertEqual(self.caps.dialect, "mysql") - @test.requireCapability(dialect="postgres") - def test_dialect_postgres(self): - self.assertEqual(self.caps.dialect, "postgres") 
+@requireCapability(dialect="postgres") +@pytest.mark.asyncio +async def test_dialect_postgres(db): + """Test postgres dialect capability.""" + caps = connections.get("models").capabilities + assert caps.dialect == "postgres" diff --git a/tests/backends/test_connection_params.py b/tests/backends/test_connection_params.py index ef346760a..c8e458d49 100644 --- a/tests/backends/test_connection_params.py +++ b/tests/backends/test_connection_params.py @@ -1,23 +1,19 @@ from unittest.mock import AsyncMock, patch import asyncpg +import pytest -from tortoise import connections -from tortoise.contrib import test +from tortoise.context import TortoiseContext -class TestConnectionParams(test.SimpleTestCase): - async def asyncSetUp(self) -> None: - await super().asyncSetUp() - - async def asyncTearDown(self) -> None: - await super().asyncTearDown() - - async def test_mysql_connection_params(self): - with patch( - "tortoise.backends.mysql.client.mysql.create_pool", new=AsyncMock() - ) as mysql_connect: - await connections._init( +@pytest.mark.asyncio +async def test_mysql_connection_params(): + with patch( + "tortoise.backends.mysql.client.mysql.create_pool", new=AsyncMock() + ) as mysql_connect: + ctx = TortoiseContext() + async with ctx: + await ctx.connections._init( { "models": { "engine": "tortoise.backends.mysql", @@ -34,7 +30,7 @@ async def test_mysql_connection_params(self): }, False, ) - await connections.get("models").create_connection(with_db=True) + await ctx.connections.get("models").create_connection(with_db=True) mysql_connect.assert_awaited_once_with( # nosec autocommit=True, @@ -50,12 +46,16 @@ async def test_mysql_connection_params(self): sql_mode="STRICT_TRANS_TABLES", ) - async def test_asyncpg_connection_params(self): - try: - with patch( - "tortoise.backends.asyncpg.client.asyncpg.create_pool", new=AsyncMock() - ) as asyncpg_connect: - await connections._init( + +@pytest.mark.asyncio +async def test_asyncpg_connection_params(): + try: + with patch( + 
"tortoise.backends.asyncpg.client.asyncpg.create_pool", new=AsyncMock() + ) as asyncpg_connect: + ctx = TortoiseContext() + async with ctx: + await ctx.connections._init( { "models": { "engine": "tortoise.backends.asyncpg", @@ -72,7 +72,7 @@ async def test_asyncpg_connection_params(self): }, False, ) - await connections.get("models").create_connection(with_db=True) + await ctx.connections.get("models").create_connection(with_db=True) asyncpg_connect.assert_awaited_once_with( # nosec None, @@ -89,17 +89,21 @@ async def test_asyncpg_connection_params(self): loop=None, server_settings={}, ) - except ImportError: - self.skipTest("asyncpg not installed") + except ImportError: + pytest.skip("asyncpg not installed") + - async def test_psycopg_connection_params(self): - try: - with patch( - "tortoise.backends.psycopg.client.PsycopgClient.create_pool", new=AsyncMock() - ) as patched_create_pool: - mocked_pool = AsyncMock() - patched_create_pool.return_value = mocked_pool - await connections._init( +@pytest.mark.asyncio +async def test_psycopg_connection_params(): + try: + with patch( + "tortoise.backends.psycopg.client.PsycopgClient.create_pool", new=AsyncMock() + ) as patched_create_pool: + mocked_pool = AsyncMock() + patched_create_pool.return_value = mocked_pool + ctx = TortoiseContext() + async with ctx: + await ctx.connections._init( { "models": { "engine": "tortoise.backends.psycopg", @@ -116,11 +120,12 @@ async def test_psycopg_connection_params(self): }, False, ) - await connections.get("models").create_connection(with_db=True) + await ctx.connections.get("models").create_connection(with_db=True) + patched_create_pool.assert_awaited_once() mocked_pool.open.assert_awaited_once_with( # nosec wait=True, timeout=1, ) - except ImportError: - self.skipTest("psycopg not installed") + except ImportError: + pytest.skip("psycopg not installed") diff --git a/tests/backends/test_db_url.py b/tests/backends/test_db_url.py index df7dd9d5d..db34d1450 100644 --- 
a/tests/backends/test_db_url.py +++ b/tests/backends/test_db_url.py @@ -1,421 +1,387 @@ +import pytest + from tortoise.backends.base.config_generator import expand_db_url, generate_config -from tortoise.contrib import test from tortoise.exceptions import ConfigurationError +# These are pure logic tests - no database fixture needed + +_postgres_scheme_engines = { + "postgres": "tortoise.backends.asyncpg", + "asyncpg": "tortoise.backends.asyncpg", + "psycopg": "tortoise.backends.psycopg", +} + + +def test_unknown_scheme(): + with pytest.raises(ConfigurationError): + expand_db_url("moo://baa") + -class TestConfigGenerator(test.SimpleTestCase): - _postgres_scheme_engines = { - "postgres": "tortoise.backends.asyncpg", - "asyncpg": "tortoise.backends.asyncpg", - "psycopg": "tortoise.backends.psycopg", +def test_sqlite_basic(): + res = expand_db_url("sqlite:///some/test.sqlite") + assert res == { + "engine": "tortoise.backends.sqlite", + "credentials": { + "file_path": "/some/test.sqlite", + "journal_mode": "WAL", + "journal_size_limit": 16384, + }, } - def test_unknown_scheme(self): - with self.assertRaises(ConfigurationError): - expand_db_url("moo://baa") - def test_sqlite_basic(self): - res = expand_db_url("sqlite:///some/test.sqlite") - self.assertDictEqual( - res, - { - "engine": "tortoise.backends.sqlite", - "credentials": { - "file_path": "/some/test.sqlite", - "journal_mode": "WAL", - "journal_size_limit": 16384, - }, - }, - ) +def test_sqlite_no_db(): + with pytest.raises(ConfigurationError, match="No path specified for DB_URL"): + expand_db_url("sqlite://") - def test_sqlite_no_db(self): - with self.assertRaisesRegex(ConfigurationError, "No path specified for DB_URL"): - expand_db_url("sqlite://") - def test_sqlite_relative(self): - res = expand_db_url("sqlite://test.sqlite") - self.assertDictEqual( - res, - { - "engine": "tortoise.backends.sqlite", - "credentials": { - "file_path": "test.sqlite", - "journal_mode": "WAL", - "journal_size_limit": 16384, - }, - }, 
- ) +def test_sqlite_relative(): + res = expand_db_url("sqlite://test.sqlite") + assert res == { + "engine": "tortoise.backends.sqlite", + "credentials": { + "file_path": "test.sqlite", + "journal_mode": "WAL", + "journal_size_limit": 16384, + }, + } - def test_sqlite_relative_with_subdir(self): - res = expand_db_url("sqlite://data/db.sqlite") - self.assertDictEqual( - res, - { - "engine": "tortoise.backends.sqlite", - "credentials": { - "file_path": "data/db.sqlite", - "journal_mode": "WAL", - "journal_size_limit": 16384, - }, - }, - ) - def test_sqlite_testing(self): - res = expand_db_url(db_url="sqlite:///some/test-{}.sqlite", testing=True) - file_path = res["credentials"]["file_path"] - self.assertIn("/some/test-", file_path) - self.assertIn(".sqlite", file_path) - self.assertNotEqual("sqlite:///some/test-{}.sqlite", file_path) - self.assertDictEqual( - res, - { - "engine": "tortoise.backends.sqlite", - "credentials": { - "file_path": file_path, - "journal_mode": "WAL", - "journal_size_limit": 16384, - }, - }, - ) +def test_sqlite_relative_with_subdir(): + res = expand_db_url("sqlite://data/db.sqlite") + assert res == { + "engine": "tortoise.backends.sqlite", + "credentials": { + "file_path": "data/db.sqlite", + "journal_mode": "WAL", + "journal_size_limit": 16384, + }, + } - def test_sqlite_params(self): - res = expand_db_url("sqlite:///some/test.sqlite?AHA=5&moo=yes&journal_mode=TRUNCATE") - self.assertDictEqual( - res, - { - "engine": "tortoise.backends.sqlite", - "credentials": { - "file_path": "/some/test.sqlite", - "AHA": "5", - "moo": "yes", - "journal_mode": "TRUNCATE", - "journal_size_limit": 16384, - }, - }, - ) - def test_sqlite_invalid(self): - with self.assertRaises(ConfigurationError): - expand_db_url("sqlite://") - - def test_postgres_basic(self): - for scheme, engine in self._postgres_scheme_engines.items(): - res = expand_db_url(f"{scheme}://postgres:moo@127.0.0.1:54321/test") - self.assertDictEqual( - res, - { - "engine": engine, - 
"credentials": { - "database": "test", - "host": "127.0.0.1", - "password": "moo", - "port": 54321, - "user": "postgres", - }, - }, - ) - - def test_postgres_encoded_password(self): - for scheme, engine in self._postgres_scheme_engines.items(): - res = expand_db_url(f"{scheme}://postgres:kx%25jj5%2Fg@127.0.0.1:54321/test") - self.assertDictEqual( - res, - { - "engine": engine, - "credentials": { - "database": "test", - "host": "127.0.0.1", - "password": "kx%jj5/g", - "port": 54321, - "user": "postgres", - }, - }, - ) - - def test_postgres_no_db(self): - for scheme, engine in self._postgres_scheme_engines.items(): - res = expand_db_url(f"{scheme}://postgres:moo@127.0.0.1:54321") - self.assertDictEqual( - res, - { - "engine": engine, - "credentials": { - "database": None, - "host": "127.0.0.1", - "password": "moo", - "port": 54321, - "user": "postgres", - }, - }, - ) - - def test_postgres_no_port(self): - for scheme, engine in self._postgres_scheme_engines.items(): - res = expand_db_url(f"{scheme}://postgres@127.0.0.1/test") - self.assertDictEqual( - res, - { - "engine": engine, - "credentials": { - "database": "test", - "host": "127.0.0.1", - "password": None, - "port": 5432, - "user": "postgres", - }, - }, - ) - - def test_postgres_nonint_port(self): - for scheme in self._postgres_scheme_engines: - with self.assertRaises(ConfigurationError): - expand_db_url(f"{scheme}://postgres:@127.0.0.1:moo/test") - - def test_postgres_testing(self): - for scheme, engine in self._postgres_scheme_engines.items(): - res = expand_db_url( - db_url=(f"{scheme}://postgres@127.0.0.1:5432/" + r"test_\{\}"), testing=True - ) - database = res["credentials"]["database"] - self.assertIn("test_", database) - self.assertNotEqual("test_{}", database) - self.assertDictEqual( - res, - { - "engine": engine, - "credentials": { - "database": database, - "host": "127.0.0.1", - "password": None, - "port": 5432, - "user": "postgres", - }, - }, - ) - - def test_postgres_params(self): - for scheme, 
engine in self._postgres_scheme_engines.items(): - res = expand_db_url(f"{scheme}://postgres@127.0.0.1:5432/test?AHA=5&moo=yes") - self.assertDictEqual( - res, - { - "engine": engine, - "credentials": { - "database": "test", - "host": "127.0.0.1", - "password": None, - "port": 5432, - "user": "postgres", - "AHA": "5", - "moo": "yes", - }, - }, - ) - - def test_mysql_basic(self): - res = expand_db_url("mysql://root:@127.0.0.1:33060/test") - self.assertEqual( - res, - { - "engine": "tortoise.backends.mysql", - "credentials": { - "database": "test", - "host": "127.0.0.1", - "password": "", - "port": 33060, - "user": "root", - "charset": "utf8mb4", - "sql_mode": "STRICT_TRANS_TABLES", - }, - }, - ) +def test_sqlite_testing(): + res = expand_db_url(db_url="sqlite:///some/test-{}.sqlite", testing=True) + file_path = res["credentials"]["file_path"] + assert "/some/test-" in file_path + assert ".sqlite" in file_path + assert "sqlite:///some/test-{}.sqlite" != file_path + assert res == { + "engine": "tortoise.backends.sqlite", + "credentials": { + "file_path": file_path, + "journal_mode": "WAL", + "journal_size_limit": 16384, + }, + } - def test_mysql_encoded_password(self): - res = expand_db_url("mysql://root:kx%25jj5%2Fg@127.0.0.1:33060/test") - self.assertEqual( - res, - { - "engine": "tortoise.backends.mysql", - "credentials": { - "database": "test", - "host": "127.0.0.1", - "password": "kx%jj5/g", - "port": 33060, - "user": "root", - "charset": "utf8mb4", - "sql_mode": "STRICT_TRANS_TABLES", - }, + +def test_sqlite_params(): + res = expand_db_url("sqlite:///some/test.sqlite?AHA=5&moo=yes&journal_mode=TRUNCATE") + assert res == { + "engine": "tortoise.backends.sqlite", + "credentials": { + "file_path": "/some/test.sqlite", + "AHA": "5", + "moo": "yes", + "journal_mode": "TRUNCATE", + "journal_size_limit": 16384, + }, + } + + +def test_sqlite_invalid(): + with pytest.raises(ConfigurationError): + expand_db_url("sqlite://") + + +def test_postgres_basic(): + for scheme, 
engine in _postgres_scheme_engines.items(): + res = expand_db_url(f"{scheme}://postgres:moo@127.0.0.1:54321/test") + assert res == { + "engine": engine, + "credentials": { + "database": "test", + "host": "127.0.0.1", + "password": "moo", + "port": 54321, + "user": "postgres", }, - ) + } - def test_mysql_no_db(self): - res = expand_db_url("mysql://root:@127.0.0.1:33060") - self.assertEqual( - res, - { - "engine": "tortoise.backends.mysql", - "credentials": { - "database": None, - "host": "127.0.0.1", - "password": "", - "port": 33060, - "user": "root", - "charset": "utf8mb4", - "sql_mode": "STRICT_TRANS_TABLES", - }, + +def test_postgres_encoded_password(): + for scheme, engine in _postgres_scheme_engines.items(): + res = expand_db_url(f"{scheme}://postgres:kx%25jj5%2Fg@127.0.0.1:54321/test") + assert res == { + "engine": engine, + "credentials": { + "database": "test", + "host": "127.0.0.1", + "password": "kx%jj5/g", + "port": 54321, + "user": "postgres", }, - ) + } - def test_mysql_no_port(self): - res = expand_db_url("mysql://root@127.0.0.1/test") - self.assertEqual( - res, - { - "engine": "tortoise.backends.mysql", - "credentials": { - "database": "test", - "host": "127.0.0.1", - "password": "", - "port": 3306, - "user": "root", - "charset": "utf8mb4", - "sql_mode": "STRICT_TRANS_TABLES", - }, + +def test_postgres_no_db(): + for scheme, engine in _postgres_scheme_engines.items(): + res = expand_db_url(f"{scheme}://postgres:moo@127.0.0.1:54321") + assert res == { + "engine": engine, + "credentials": { + "database": None, + "host": "127.0.0.1", + "password": "moo", + "port": 54321, + "user": "postgres", }, - ) + } - def test_mysql_nonint_port(self): - with self.assertRaises(ConfigurationError): - expand_db_url("mysql://root:@127.0.0.1:moo/test") - - def test_mysql_testing(self): - res = expand_db_url(r"mysql://root:@127.0.0.1:3306/test_\{\}", testing=True) - self.assertIn("test_", res["credentials"]["database"]) - self.assertNotEqual("test_{}", 
res["credentials"]["database"]) - self.assertEqual( - res, - { - "engine": "tortoise.backends.mysql", - "credentials": { - "database": res["credentials"]["database"], - "host": "127.0.0.1", - "password": "", - "port": 3306, - "user": "root", - "charset": "utf8mb4", - "sql_mode": "STRICT_TRANS_TABLES", - }, + +def test_postgres_no_port(): + for scheme, engine in _postgres_scheme_engines.items(): + res = expand_db_url(f"{scheme}://postgres@127.0.0.1/test") + assert res == { + "engine": engine, + "credentials": { + "database": "test", + "host": "127.0.0.1", + "password": None, + "port": 5432, + "user": "postgres", }, - ) + } + - def test_mysql_params(self): +def test_postgres_nonint_port(): + for scheme in _postgres_scheme_engines: + with pytest.raises(ConfigurationError): + expand_db_url(f"{scheme}://postgres:@127.0.0.1:moo/test") + + +def test_postgres_testing(): + for scheme, engine in _postgres_scheme_engines.items(): res = expand_db_url( - "mysql://root:@127.0.0.1:3306/test?AHA=5&moo=yes&maxsize=20&minsize=5" - "&connect_timeout=1.5&echo=1&ssl=True" + db_url=(f"{scheme}://postgres@127.0.0.1:5432/" + r"test_\{\}"), testing=True ) - self.assertEqual( - res, - { - "engine": "tortoise.backends.mysql", - "credentials": { - "database": "test", - "host": "127.0.0.1", - "password": "", - "port": 3306, - "user": "root", - "AHA": "5", - "moo": "yes", - "minsize": 5, - "maxsize": 20, - "connect_timeout": 1.5, - "echo": True, - "charset": "utf8mb4", - "sql_mode": "STRICT_TRANS_TABLES", - "ssl": True, - }, + database = res["credentials"]["database"] + assert "test_" in database + assert "test_{}" != database + assert res == { + "engine": engine, + "credentials": { + "database": database, + "host": "127.0.0.1", + "password": None, + "port": 5432, + "user": "postgres", }, - ) + } - def test_generate_config_basic(self): - res = generate_config( - db_url="sqlite:///some/test.sqlite", - app_modules={"models": ["one.models", "two.models"]}, - ) - self.assertEqual( - res, - { - 
"connections": { - "default": { - "credentials": { - "file_path": "/some/test.sqlite", - "journal_mode": "WAL", - "journal_size_limit": 16384, - }, - "engine": "tortoise.backends.sqlite", - } - }, - "apps": { - "models": { - "models": ["one.models", "two.models"], - "default_connection": "default", - } - }, + +def test_postgres_params(): + for scheme, engine in _postgres_scheme_engines.items(): + res = expand_db_url(f"{scheme}://postgres@127.0.0.1:5432/test?AHA=5&moo=yes") + assert res == { + "engine": engine, + "credentials": { + "database": "test", + "host": "127.0.0.1", + "password": None, + "port": 5432, + "user": "postgres", + "AHA": "5", + "moo": "yes", }, - ) + } - def test_generate_config_explicit(self): - res = generate_config( - db_url="sqlite:///some/test.sqlite", - app_modules={"models": ["one.models", "two.models"]}, - connection_label="models", - testing=True, - ) - self.assertEqual( - res, - { - "connections": { - "models": { - "credentials": { - "file_path": "/some/test.sqlite", - "journal_mode": "WAL", - "journal_size_limit": 16384, - }, - "engine": "tortoise.backends.sqlite", - } - }, - "apps": { - "models": { - "models": ["one.models", "two.models"], - "default_connection": "models", - } + +def test_mysql_basic(): + res = expand_db_url("mysql://root:@127.0.0.1:33060/test") + assert res == { + "engine": "tortoise.backends.mysql", + "credentials": { + "database": "test", + "host": "127.0.0.1", + "password": "", + "port": 33060, + "user": "root", + "charset": "utf8mb4", + "sql_mode": "STRICT_TRANS_TABLES", + }, + } + + +def test_mysql_encoded_password(): + res = expand_db_url("mysql://root:kx%25jj5%2Fg@127.0.0.1:33060/test") + assert res == { + "engine": "tortoise.backends.mysql", + "credentials": { + "database": "test", + "host": "127.0.0.1", + "password": "kx%jj5/g", + "port": 33060, + "user": "root", + "charset": "utf8mb4", + "sql_mode": "STRICT_TRANS_TABLES", + }, + } + + +def test_mysql_no_db(): + res = 
expand_db_url("mysql://root:@127.0.0.1:33060") + assert res == { + "engine": "tortoise.backends.mysql", + "credentials": { + "database": None, + "host": "127.0.0.1", + "password": "", + "port": 33060, + "user": "root", + "charset": "utf8mb4", + "sql_mode": "STRICT_TRANS_TABLES", + }, + } + + +def test_mysql_no_port(): + res = expand_db_url("mysql://root@127.0.0.1/test") + assert res == { + "engine": "tortoise.backends.mysql", + "credentials": { + "database": "test", + "host": "127.0.0.1", + "password": "", + "port": 3306, + "user": "root", + "charset": "utf8mb4", + "sql_mode": "STRICT_TRANS_TABLES", + }, + } + + +def test_mysql_nonint_port(): + with pytest.raises(ConfigurationError): + expand_db_url("mysql://root:@127.0.0.1:moo/test") + + +def test_mysql_testing(): + res = expand_db_url(r"mysql://root:@127.0.0.1:3306/test_\{\}", testing=True) + assert "test_" in res["credentials"]["database"] + assert "test_{}" != res["credentials"]["database"] + assert res == { + "engine": "tortoise.backends.mysql", + "credentials": { + "database": res["credentials"]["database"], + "host": "127.0.0.1", + "password": "", + "port": 3306, + "user": "root", + "charset": "utf8mb4", + "sql_mode": "STRICT_TRANS_TABLES", + }, + } + + +def test_mysql_params(): + res = expand_db_url( + "mysql://root:@127.0.0.1:3306/test?AHA=5&moo=yes&maxsize=20&minsize=5" + "&connect_timeout=1.5&echo=1&ssl=True" + ) + assert res == { + "engine": "tortoise.backends.mysql", + "credentials": { + "database": "test", + "host": "127.0.0.1", + "password": "", + "port": 3306, + "user": "root", + "AHA": "5", + "moo": "yes", + "minsize": 5, + "maxsize": 20, + "connect_timeout": 1.5, + "echo": True, + "charset": "utf8mb4", + "sql_mode": "STRICT_TRANS_TABLES", + "ssl": True, + }, + } + + +def test_generate_config_basic(): + res = generate_config( + db_url="sqlite:///some/test.sqlite", + app_modules={"models": ["one.models", "two.models"]}, + ) + assert res == { + "connections": { + "default": { + "credentials": { + 
"file_path": "/some/test.sqlite", + "journal_mode": "WAL", + "journal_size_limit": 16384, }, - }, - ) + "engine": "tortoise.backends.sqlite", + } + }, + "apps": { + "models": { + "models": ["one.models", "two.models"], + "default_connection": "default", + } + }, + } - def test_generate_config_many_apps(self): - res = generate_config( - db_url="sqlite:///some/test.sqlite", - app_modules={"models": ["one.models", "two.models"], "peanuts": ["peanut.models"]}, - ) - self.assertEqual( - res, - { - "connections": { - "default": { - "credentials": { - "file_path": "/some/test.sqlite", - "journal_mode": "WAL", - "journal_size_limit": 16384, - }, - "engine": "tortoise.backends.sqlite", - } + +def test_generate_config_explicit(): + res = generate_config( + db_url="sqlite:///some/test.sqlite", + app_modules={"models": ["one.models", "two.models"]}, + connection_label="models", + testing=True, + ) + assert res == { + "connections": { + "models": { + "credentials": { + "file_path": "/some/test.sqlite", + "journal_mode": "WAL", + "journal_size_limit": 16384, }, - "apps": { - "models": { - "models": ["one.models", "two.models"], - "default_connection": "default", - }, - "peanuts": {"models": ["peanut.models"], "default_connection": "default"}, + "engine": "tortoise.backends.sqlite", + } + }, + "apps": { + "models": { + "models": ["one.models", "two.models"], + "default_connection": "models", + } + }, + } + + +def test_generate_config_many_apps(): + res = generate_config( + db_url="sqlite:///some/test.sqlite", + app_modules={"models": ["one.models", "two.models"], "peanuts": ["peanut.models"]}, + ) + assert res == { + "connections": { + "default": { + "credentials": { + "file_path": "/some/test.sqlite", + "journal_mode": "WAL", + "journal_size_limit": 16384, }, + "engine": "tortoise.backends.sqlite", + } + }, + "apps": { + "models": { + "models": ["one.models", "two.models"], + "default_connection": "default", }, - ) + "peanuts": {"models": ["peanut.models"], "default_connection": 
"default"}, + }, + } diff --git a/tests/backends/test_explain.py b/tests/backends/test_explain.py index 3680d5888..e16d91c91 100644 --- a/tests/backends/test_explain.py +++ b/tests/backends/test_explain.py @@ -1,15 +1,20 @@ +import pytest + from tests.testmodels import Tournament -from tortoise.contrib import test +from tortoise.contrib.test import requireCapability from tortoise.contrib.test.condition import NotEQ -class TestExplain(test.TestCase): - @test.requireCapability(dialect=NotEQ("mssql")) - async def test_explain(self): - # NOTE: we do not provide any guarantee on the format of the value - # returned by `.explain()`, as it heavily depends on the database. - # This test merely checks that one is able to run `.explain()` - # without errors for each backend. - plan = await Tournament.all().explain() - # This should have returned *some* information. - self.assertGreater(len(str(plan)), 20) +@requireCapability(dialect=NotEQ("mssql")) +@pytest.mark.asyncio +async def test_explain(db): + """Test that explain() returns query plan information. + + NOTE: we do not provide any guarantee on the format of the value + returned by `.explain()`, as it heavily depends on the database. + This test merely checks that one is able to run `.explain()` + without errors for each backend. + """ + plan = await Tournament.all().explain() + # This should have returned *some* information. 
+ assert len(str(plan)) > 20 diff --git a/tests/backends/test_mysql.py b/tests/backends/test_mysql.py index b95533e4b..37f7c1cf4 100644 --- a/tests/backends/test_mysql.py +++ b/tests/backends/test_mysql.py @@ -2,51 +2,88 @@ Test some mysql-specific features """ +import copy +import os import ssl -from tortoise import Tortoise -from tortoise.contrib import test +import pytest +from tortoise.backends.base.config_generator import generate_config +from tortoise.context import TortoiseContext -class TestMySQL(test.SimpleTestCase): - async def asyncSetUp(self): - if Tortoise._inited: - await self._tearDownDB() - self.db_config = test.getDBConfig(app_label="models", modules=["tests.testmodels"]) - if self.db_config["connections"]["models"]["engine"] != "tortoise.backends.mysql": - raise test.SkipTest("MySQL only") - async def asyncTearDown(self) -> None: - if Tortoise._inited: - await Tortoise._drop_databases() - await super().asyncTearDown() +def _get_db_config(): + """Get database config and check if it's MySQL.""" + db_url = os.getenv("TORTOISE_TEST_DB", "sqlite://:memory:") + db_config = generate_config( + db_url, + app_modules={"models": ["tests.testmodels"]}, + connection_label="models", + testing=True, + ) + engine = db_config["connections"]["models"]["engine"] + is_mysql = engine == "tortoise.backends.mysql" + return db_config, is_mysql - async def test_bad_charset(self): - self.db_config["connections"]["models"]["credentials"]["charset"] = "terrible" - with self.assertRaisesRegex(ConnectionError, "Unknown charset"): - await Tortoise.init(self.db_config, _create_db=True) - async def test_ssl_true(self): - self.db_config["connections"]["models"]["credentials"]["ssl"] = True - try: - import asyncmy # noqa pylint: disable=unused-import +@pytest.mark.asyncio +async def test_bad_charset(): + """Test that invalid charset raises ConnectionError.""" + base_config, is_mysql = _get_db_config() + if not is_mysql: + pytest.skip("MySQL only") - # setting read_timeout for 
asyncmy. Otherwise, it will hang forever. - self.db_config["connections"]["models"]["credentials"]["read_timeout"] = 1 - except ImportError: - pass + # Deep copy to avoid modifying shared config + db_config = copy.deepcopy(base_config) + db_config["connections"]["models"]["credentials"]["charset"] = "terrible" + + async with TortoiseContext() as ctx: + with pytest.raises(ConnectionError, match="Unknown charset"): + await ctx.init(db_config, _create_db=True) + + +@pytest.mark.asyncio +async def test_ssl_true(): + """Test that SSL=True with no cert raises ConnectionError.""" + base_config, is_mysql = _get_db_config() + if not is_mysql: + pytest.skip("MySQL only") + + # Deep copy to avoid modifying shared config + db_config = copy.deepcopy(base_config) + db_config["connections"]["models"]["credentials"]["ssl"] = True + try: + import asyncmy # noqa pylint: disable=unused-import + + # setting read_timeout for asyncmy. Otherwise, it will hang forever. + db_config["connections"]["models"]["credentials"]["read_timeout"] = 1 + except ImportError: + pass + + async with TortoiseContext() as ctx: + with pytest.raises(ConnectionError): + await ctx.init(db_config, _create_db=True) + + +@pytest.mark.asyncio +async def test_ssl_custom(): + """Test SSL with custom context (may pass or fail depending on server).""" + base_config, is_mysql = _get_db_config() + if not is_mysql: + pytest.skip("MySQL only") + + # Deep copy to avoid modifying shared config + db_config = copy.deepcopy(base_config) - with self.assertRaises(ConnectionError): - await Tortoise.init(self.db_config, _create_db=True) + # Expect connectionerror or pass + ssl_ctx = ssl.create_default_context() + ssl_ctx.check_hostname = False + ssl_ctx.verify_mode = ssl.CERT_NONE - async def test_ssl_custom(self): - # Expect connectionerror or pass - ctx = ssl.create_default_context() - ctx.check_hostname = False - ctx.verify_mode = ssl.CERT_NONE + db_config["connections"]["models"]["credentials"]["ssl"] = ssl_ctx - 
self.db_config["connections"]["models"]["credentials"]["ssl"] = ctx + async with TortoiseContext() as ctx: try: - await Tortoise.init(self.db_config, _create_db=True) + await ctx.init(db_config, _create_db=True) except ConnectionError: pass diff --git a/tests/backends/test_postgres.py b/tests/backends/test_postgres.py index 98e3a1c2e..694469cd1 100644 --- a/tests/backends/test_postgres.py +++ b/tests/backends/test_postgres.py @@ -2,46 +2,51 @@ Test some PostgreSQL-specific features """ +import os import ssl +import pytest + from tests.testmodels import Tournament from tortoise import Tortoise, connections -from tortoise.contrib import test +from tortoise.backends.base.config_generator import generate_config from tortoise.exceptions import OperationalError -class TestPostgreSQL(test.SimpleTestCase): - async def asyncSetUp(self): - await super().asyncSetUp() - if Tortoise._inited: - await self._tearDownDB() - self.db_config = test.getDBConfig(app_label="models", modules=["tests.testmodels"]) - if not self.is_asyncpg and not self.is_psycopg: - raise test.SkipTest("PostgreSQL only") - - @property - def is_psycopg(self) -> bool: - return self.db_config["connections"]["models"]["engine"] == "tortoise.backends.psycopg" - - @property - def is_asyncpg(self) -> bool: - return self.db_config["connections"]["models"]["engine"] == "tortoise.backends.asyncpg" - - async def asyncTearDown(self) -> None: - if Tortoise._inited: - await Tortoise._drop_databases() - await super().asyncTearDown() - - async def test_schema(self): - if self.is_asyncpg: - from asyncpg.exceptions import InvalidSchemaNameError - else: - from psycopg.errors import InvalidSchemaName as InvalidSchemaNameError - - self.db_config["connections"]["models"]["credentials"]["schema"] = "mytestschema" - await Tortoise.init(self.db_config, _create_db=True) - - with self.assertRaises(InvalidSchemaNameError): +def _get_db_config(): + """Get database config and check if it's PostgreSQL.""" + db_url = 
os.getenv("TORTOISE_TEST_DB", "sqlite://:memory:") + db_config = generate_config( + db_url, + app_modules={"models": ["tests.testmodels"]}, + connection_label="models", + testing=True, + ) + engine = db_config["connections"]["models"]["engine"] + is_asyncpg = engine == "tortoise.backends.asyncpg" + is_psycopg = engine == "tortoise.backends.psycopg" + return db_config, is_asyncpg, is_psycopg + + +@pytest.mark.asyncio +async def test_schema(db_simple): + db_config, is_asyncpg, is_psycopg = _get_db_config() + if not is_asyncpg and not is_psycopg: + pytest.skip("PostgreSQL only") + + if is_asyncpg: + from asyncpg.exceptions import InvalidSchemaNameError + else: + from psycopg.errors import InvalidSchemaName as InvalidSchemaNameError + + if Tortoise._inited: + await Tortoise._drop_databases() + + try: + db_config["connections"]["models"]["credentials"]["schema"] = "mytestschema" + await Tortoise.init(db_config, _create_db=True) + + with pytest.raises(InvalidSchemaNameError): await Tortoise.generate_schemas() conn = connections.get("models") @@ -51,10 +56,10 @@ async def test_schema(self): tournament = await Tournament.create(name="Test") await connections.close_all() - del self.db_config["connections"]["models"]["credentials"]["schema"] - await Tortoise.init(self.db_config) + del db_config["connections"]["models"]["credentials"]["schema"] + await Tortoise.init(db_config) - with self.assertRaises(OperationalError): + with pytest.raises(OperationalError): await Tournament.filter(name="Test").first() conn = connections.get("models") @@ -62,41 +67,74 @@ async def test_schema(self): "SELECT id, name FROM mytestschema.tournament WHERE name='Test' LIMIT 1" ) - self.assertEqual(len(res), 1) - self.assertEqual(tournament.id, res[0]["id"]) - self.assertEqual(tournament.name, res[0]["name"]) - - async def test_ssl_true(self): - self.db_config["connections"]["models"]["credentials"]["ssl"] = True - try: - await Tortoise.init(self.db_config, _create_db=True) - except 
(ConnectionError, ssl.SSLError): - pass - else: - self.assertFalse(True, "Expected ConnectionError or SSLError") - - async def test_ssl_custom(self): - # Expect connectionerror or pass - ctx = ssl.create_default_context() - ctx.check_hostname = False - ctx.verify_mode = ssl.CERT_NONE - - self.db_config["connections"]["models"]["credentials"]["ssl"] = ctx - try: - await Tortoise.init(self.db_config, _create_db=True) - except ConnectionError: - pass - - async def test_application_name(self): - self.db_config["connections"]["models"]["credentials"]["application_name"] = ( - "mytest_application" - ) - await Tortoise.init(self.db_config, _create_db=True) + assert len(res) == 1 + assert tournament.id == res[0]["id"] + assert tournament.name == res[0]["name"] + finally: + if Tortoise._inited: + await Tortoise._drop_databases() + + +@pytest.mark.asyncio +async def test_ssl_true(): + db_config, is_asyncpg, is_psycopg = _get_db_config() + if not is_asyncpg and not is_psycopg: + pytest.skip("PostgreSQL only") + + db_config["connections"]["models"]["credentials"]["ssl"] = True + ssl_failed = False + try: + await Tortoise.init(db_config, _create_db=True) + except (ConnectionError, ssl.SSLError): + ssl_failed = True + else: + assert False, "Expected ConnectionError or SSLError" + finally: + # Don't try to drop database if SSL connection failed - we can't connect + if Tortoise._inited and not ssl_failed: + await Tortoise._drop_databases() + + +@pytest.mark.asyncio +async def test_ssl_custom(): + db_config, is_asyncpg, is_psycopg = _get_db_config() + if not is_asyncpg and not is_psycopg: + pytest.skip("PostgreSQL only") + + # Expect connectionerror or pass + ssl_ctx = ssl.create_default_context() + ssl_ctx.check_hostname = False + ssl_ctx.verify_mode = ssl.CERT_NONE + + db_config["connections"]["models"]["credentials"]["ssl"] = ssl_ctx + ssl_failed = False + try: + await Tortoise.init(db_config, _create_db=True) + except ConnectionError: + ssl_failed = True + finally: + # Don't 
try to drop database if SSL connection failed - we can't connect + if Tortoise._inited and not ssl_failed: + await Tortoise._drop_databases() + + +@pytest.mark.asyncio +async def test_application_name(): + db_config, is_asyncpg, is_psycopg = _get_db_config() + if not is_asyncpg and not is_psycopg: + pytest.skip("PostgreSQL only") + + db_config["connections"]["models"]["credentials"]["application_name"] = "mytest_application" + try: + await Tortoise.init(db_config, _create_db=True) conn = connections.get("models") _, res = await conn.execute_query( "SELECT application_name FROM pg_stat_activity WHERE pid = pg_backend_pid()" ) - self.assertEqual(len(res), 1) - self.assertEqual("mytest_application", res[0]["application_name"]) + assert len(res) == 1 + assert "mytest_application" == res[0]["application_name"] + finally: + if Tortoise._inited: + await Tortoise._drop_databases() diff --git a/tests/backends/test_reconnect.py b/tests/backends/test_reconnect.py index b0e07ef6b..3b546bfb7 100644 --- a/tests/backends/test_reconnect.py +++ b/tests/backends/test_reconnect.py @@ -1,37 +1,41 @@ +import pytest + from tests.testmodels import Tournament from tortoise import connections -from tortoise.contrib import test +from tortoise.contrib.test import requireCapability from tortoise.transactions import in_transaction -@test.requireCapability(daemon=True) -class TestReconnect(test.IsolatedTestCase): - async def test_reconnect(self): - await Tournament.create(name="1") +@requireCapability(daemon=True) +@pytest.mark.asyncio +async def test_reconnect(db_isolated): + """Test reconnection after connection expiry.""" + await Tournament.create(name="1") - await connections.get("models")._expire_connections() + await connections.get("models")._expire_connections() - await Tournament.create(name="2") + await Tournament.create(name="2") - await connections.get("models")._expire_connections() + await connections.get("models")._expire_connections() - await Tournament.create(name="3") + await 
Tournament.create(name="3") - self.assertEqual( - [f"{a.id}:{a.name}" for a in await Tournament.all()], ["1:1", "2:2", "3:3"] - ) + assert [f"{a.id}:{a.name}" for a in await Tournament.all()] == ["1:1", "2:2", "3:3"] - @test.requireCapability(supports_transactions=True) - async def test_reconnect_transaction_start(self): - async with in_transaction(): - await Tournament.create(name="1") - await connections.get("models")._expire_connections() +@requireCapability(daemon=True, supports_transactions=True) +@pytest.mark.asyncio +async def test_reconnect_transaction_start(db_isolated): + """Test reconnection at transaction start.""" + async with in_transaction(): + await Tournament.create(name="1") + + await connections.get("models")._expire_connections() - async with in_transaction(): - await Tournament.create(name="2") + async with in_transaction(): + await Tournament.create(name="2") - await connections.get("models")._expire_connections() + await connections.get("models")._expire_connections() - async with in_transaction(): - self.assertEqual([f"{a.id}:{a.name}" for a in await Tournament.all()], ["1:1", "2:2"]) + async with in_transaction(): + assert [f"{a.id}:{a.name}" for a in await Tournament.all()] == ["1:1", "2:2"] diff --git a/tests/benchmarks/conftest.py b/tests/benchmarks/conftest.py index 64f9abaca..7c92e99b1 100644 --- a/tests/benchmarks/conftest.py +++ b/tests/benchmarks/conftest.py @@ -14,12 +14,12 @@ Team, Tournament, ) -from tortoise.contrib.test import _restore_default, truncate_all_models +from tortoise.contrib.test import truncate_all_models @pytest.fixture(scope="function", autouse=True) -def setup_database(): - _restore_default() +def setup_database(db): + """Cleanup fixture that depends on db to ensure context is active.""" yield asyncio.get_event_loop().run_until_complete(truncate_all_models()) @@ -31,7 +31,7 @@ def skip_if_codspeed_not_enabled(request): @pytest.fixture -def few_fields_benchmark_dataset() -> list[BenchmarkFewFields]: +def 
few_fields_benchmark_dataset(db) -> list[BenchmarkFewFields]: async def _create() -> list[BenchmarkFewFields]: res = [] for _ in range(100): @@ -43,7 +43,7 @@ async def _create() -> list[BenchmarkFewFields]: @pytest.fixture -def many_fields_benchmark_dataset(gen_many_fields_data) -> list[BenchmarkManyFields]: +def many_fields_benchmark_dataset(db, gen_many_fields_data) -> list[BenchmarkManyFields]: async def _create() -> list[BenchmarkManyFields]: res = [] for _ in range(100): @@ -97,7 +97,7 @@ def _gen(): @pytest.fixture -def create_team_with_participants() -> None: +def create_team_with_participants(db) -> None: async def _create() -> None: tournament = await Tournament.create(name="New Tournament") event = await Event.create(name="Test", tournament_id=tournament.id) @@ -108,7 +108,7 @@ async def _create() -> None: @pytest.fixture -def create_decimals() -> None: +def create_decimals(db) -> None: async def _create() -> None: await DecimalFields.create(decimal=Decimal("1.23456"), decimal_nodec=18.7) diff --git a/tests/cli/test_cli.py b/tests/cli/test_cli.py index 91e42da06..1a4c9c8b7 100644 --- a/tests/cli/test_cli.py +++ b/tests/cli/test_cli.py @@ -315,7 +315,7 @@ async def fake_applied(self) -> list[MigrationKey]: monkeypatch.setattr(cli_module.Tortoise, "init", fake_init) monkeypatch.setattr(cli_module.MigrationRecorder, "applied_migrations", fake_applied) - monkeypatch.setattr(cli_module.connections, "get", lambda _name: object()) + monkeypatch.setattr(cli_module, "get_connection", lambda _name: object()) result = await _run_cli(["-c", f"{module_name}.TORTOISE_ORM", "history"]) assert result.exit_code == 0 diff --git a/tests/contrib/mysql/fields.py b/tests/contrib/mysql/fields.py index 8426e96eb..47a91cdb5 100644 --- a/tests/contrib/mysql/fields.py +++ b/tests/contrib/mysql/fields.py @@ -1,46 +1,53 @@ import uuid +import pytest + from tests import testmodels_mysql -from tortoise.contrib import test from tortoise.exceptions import IntegrityError -class 
TestMySQLUUIDFields(test.TestCase): - async def test_empty(self): - with self.assertRaises(IntegrityError): - await testmodels_mysql.UUIDFields.create() - - async def test_create(self): - data = uuid.uuid4() - obj0 = await testmodels_mysql.UUIDFields.create(data=data) - self.assertIsInstance(obj0.data, bytes) - self.assertIsInstance(obj0.data_auto, bytes) - self.assertEqual(obj0.data_null, None) - obj = await testmodels_mysql.UUIDFields.get(id=obj0.id) - self.assertIsInstance(obj.data, uuid.UUID) - self.assertIsInstance(obj.data_auto, uuid.UUID) - self.assertEqual(obj.data, data) - self.assertEqual(obj.data_null, None) - await obj.save() - obj2 = await testmodels_mysql.UUIDFields.get(id=obj.id) - self.assertEqual(obj, obj2) - - await obj.delete() - obj = await testmodels_mysql.UUIDFields.filter(id=obj0.id).first() - self.assertEqual(obj, None) - - async def test_update(self): - data = uuid.uuid4() - data2 = uuid.uuid4() - obj0 = await testmodels_mysql.UUIDFields.create(data=data) - await testmodels_mysql.UUIDFields.filter(id=obj0.id).update(data=data2) - obj = await testmodels_mysql.UUIDFields.get(id=obj0.id) - self.assertEqual(obj.data, data2) - self.assertEqual(obj.data_null, None) - - async def test_create_not_null(self): - data = uuid.uuid4() - obj0 = await testmodels_mysql.UUIDFields.create(data=data, data_null=data) - obj = await testmodels_mysql.UUIDFields.get(id=obj0.id) - self.assertEqual(obj.data, data) - self.assertEqual(obj.data_null, data) +@pytest.mark.asyncio +async def test_empty(db): + with pytest.raises(IntegrityError): + await testmodels_mysql.UUIDFields.create() + + +@pytest.mark.asyncio +async def test_create(db): + data = uuid.uuid4() + obj0 = await testmodels_mysql.UUIDFields.create(data=data) + assert isinstance(obj0.data, bytes) + assert isinstance(obj0.data_auto, bytes) + assert obj0.data_null is None + obj = await testmodels_mysql.UUIDFields.get(id=obj0.id) + assert isinstance(obj.data, uuid.UUID) + assert isinstance(obj.data_auto, 
uuid.UUID) + assert obj.data == data + assert obj.data_null is None + await obj.save() + obj2 = await testmodels_mysql.UUIDFields.get(id=obj.id) + assert obj == obj2 + + await obj.delete() + obj = await testmodels_mysql.UUIDFields.filter(id=obj0.id).first() + assert obj is None + + +@pytest.mark.asyncio +async def test_update(db): + data = uuid.uuid4() + data2 = uuid.uuid4() + obj0 = await testmodels_mysql.UUIDFields.create(data=data) + await testmodels_mysql.UUIDFields.filter(id=obj0.id).update(data=data2) + obj = await testmodels_mysql.UUIDFields.get(id=obj0.id) + assert obj.data == data2 + assert obj.data_null is None + + +@pytest.mark.asyncio +async def test_create_not_null(db): + data = uuid.uuid4() + obj0 = await testmodels_mysql.UUIDFields.create(data=data, data_null=data) + obj = await testmodels_mysql.UUIDFields.get(id=obj0.id) + assert obj.data == data + assert obj.data_null == data diff --git a/tests/contrib/postgres/conftest.py b/tests/contrib/postgres/conftest.py new file mode 100644 index 000000000..ca1ac8ca7 --- /dev/null +++ b/tests/contrib/postgres/conftest.py @@ -0,0 +1,98 @@ +""" +Custom fixtures for PostgreSQL-specific tests that require specific model modules. + +These fixtures support tests that define tortoise_test_modules to use +custom model definitions for PostgreSQL features like TSVector. +""" + +import os + +import pytest +import pytest_asyncio + +from tortoise.context import tortoise_test_context + + +def skip_if_not_postgres(): + """Skip test if not running against PostgreSQL.""" + db_url = os.getenv("TORTOISE_TEST_DB", "") + if db_url.split(":", 1)[0] not in {"postgres", "asyncpg", "psycopg"}: + pytest.skip("Postgres-only test.") + + +@pytest_asyncio.fixture(scope="module") +async def db_module_postgres(): + """ + Module-scoped fixture for postgres tests using standard testmodels. + + Creates a TortoiseContext with tests.testmodels once per test module. + Used as base for postgres tests that need standard models like TextFields. 
+ """ + skip_if_not_postgres() + db_url = os.getenv("TORTOISE_TEST_DB", "sqlite://:memory:") + async with tortoise_test_context( + modules=["tests.testmodels"], + db_url=db_url, + app_label="models", + connection_label="models", + ) as ctx: + yield ctx + + +@pytest_asyncio.fixture(scope="function") +async def db_postgres(db_module_postgres): + """ + Function-scoped fixture with transaction rollback for postgres tests. + + Equivalent to: test.TestCase with standard testmodels. + """ + conn = db_module_postgres.db() + transaction = conn._in_transaction() + await transaction.__aenter__() + + try: + yield db_module_postgres + finally: + + class _RollbackException(Exception): + pass + + await transaction.__aexit__(_RollbackException, _RollbackException(), None) + + +@pytest_asyncio.fixture(scope="function") +async def db_tsvector(): + """ + Fixture for TestTSVectorField. + + Uses models defined in tests.contrib.postgres.models_tsvector module. + Equivalent to: test.IsolatedTestCase with tortoise_test_modules + """ + skip_if_not_postgres() + db_url = os.getenv("TORTOISE_TEST_DB", "sqlite://:memory:") + async with tortoise_test_context( + modules=["tests.contrib.postgres.models_tsvector"], + db_url=db_url, + app_label="models", + connection_label="models", + ) as ctx: + yield ctx + + +@pytest_asyncio.fixture(scope="function") +async def db_search(): + """ + Fixture for TestPostgresSearchLookupTSVector. + + Uses models defined in tests.contrib.postgres.models_tsvector module. 
+ Equivalent to: test.IsolatedTestCase with tortoise_test_modules + """ + skip_if_not_postgres() + db_url = os.getenv("TORTOISE_TEST_DB", "sqlite://:memory:") + async with tortoise_test_context( + modules=["tests.contrib.postgres.models_tsvector"], + db_url=db_url, + app_label="models", + connection_label="models", + ) as ctx: + yield ctx diff --git a/tests/contrib/postgres/test_json.py b/tests/contrib/postgres/test_json.py index a6d0ab96d..93acef982 100644 --- a/tests/contrib/postgres/test_json.py +++ b/tests/contrib/postgres/test_json.py @@ -1,140 +1,167 @@ from datetime import datetime from decimal import Decimal +import pytest +import pytest_asyncio + from tests.testmodels import JSONFields from tortoise.contrib import test from tortoise.exceptions import DoesNotExist -@test.requireCapability(dialect="postgres") -class TestPostgresJSON(test.TestCase): - async def asyncSetUp(self): - await super().asyncSetUp() - self.obj = await JSONFields.create( - data={ - "val": "word1", - "int_val": 123, - "float_val": 123.1, +async def get_by_data_filter(obj, **kwargs) -> JSONFields: + return await JSONFields.get(data__filter=kwargs) + + +@pytest_asyncio.fixture +async def json_obj(db_postgres): + """Create test object with JSON data for postgres tests.""" + obj = await JSONFields.create( + data={ + "val": "word1", + "int_val": 123, + "float_val": 123.1, + "date_val": datetime(1970, 1, 1, 12, 36, 59, 123456), + "int_list": [1, 2, 3], + "nested": { + "val": "word2", + "int_val": 456, + "int_list": [4, 5, 6], "date_val": datetime(1970, 1, 1, 12, 36, 59, 123456), - "int_list": [1, 2, 3], "nested": { - "val": "word2", - "int_val": 456, - "int_list": [4, 5, 6], - "date_val": datetime(1970, 1, 1, 12, 36, 59, 123456), - "nested": { - "val": "word3", - }, + "val": "word3", }, - } - ) + }, + } + ) + return obj - async def get_by_data_filter(self, **kwargs) -> JSONFields: - return await JSONFields.get(data__filter=kwargs) - - async def test_json_in(self): - self.assertEqual(await 
self.get_by_data_filter(val__in=["word1", "word2"]), self.obj) - self.assertEqual(await self.get_by_data_filter(val__not_in=["word3", "word4"]), self.obj) - - with self.assertRaises(DoesNotExist): - await self.get_by_data_filter(val__in=["doesnotexist"]) - - async def test_json_defaults(self): - self.assertEqual(await self.get_by_data_filter(val__not="word2"), self.obj) - self.assertEqual(await self.get_by_data_filter(val__isnull=False), self.obj) - self.assertEqual(await self.get_by_data_filter(val__not_isnull=True), self.obj) - - async def test_json_int_comparisons(self): - self.assertEqual(await self.get_by_data_filter(int_val=123), self.obj) - self.assertEqual(await self.get_by_data_filter(int_val__gt=100), self.obj) - self.assertEqual(await self.get_by_data_filter(int_val__gte=100), self.obj) - self.assertEqual(await self.get_by_data_filter(int_val__lt=200), self.obj) - self.assertEqual(await self.get_by_data_filter(int_val__lte=200), self.obj) - self.assertEqual(await self.get_by_data_filter(int_val__range=[100, 200]), self.obj) - - with self.assertRaises(DoesNotExist): - await self.get_by_data_filter(int_val__gt=1000) - - async def test_json_float_comparisons(self): - self.assertEqual(await self.get_by_data_filter(float_val__gt=100.0), self.obj) - self.assertEqual(await self.get_by_data_filter(float_val__gte=100.0), self.obj) - self.assertEqual(await self.get_by_data_filter(float_val__lt=200.0), self.obj) - self.assertEqual(await self.get_by_data_filter(float_val__lte=200.0), self.obj) - self.assertEqual(await self.get_by_data_filter(float_val__range=[100.0, 200.0]), self.obj) - - with self.assertRaises(DoesNotExist): - await self.get_by_data_filter(int_val__gt=1000.0) - - async def test_json_string_comparisons(self): - self.assertEqual(await self.get_by_data_filter(val__contains="ord"), self.obj) - self.assertEqual(await self.get_by_data_filter(val__icontains="OrD"), self.obj) - self.assertEqual(await self.get_by_data_filter(val__startswith="wor"), 
self.obj) - self.assertEqual(await self.get_by_data_filter(val__istartswith="wOr"), self.obj) - self.assertEqual(await self.get_by_data_filter(val__endswith="rd1"), self.obj) - self.assertEqual(await self.get_by_data_filter(val__iendswith="Rd1"), self.obj) - self.assertEqual(await self.get_by_data_filter(val__iexact="wOrD1"), self.obj) - - with self.assertRaises(DoesNotExist): - await self.get_by_data_filter(val__contains="doesnotexist") - - async def test_date_comparisons(self): - self.assertEqual( - await self.get_by_data_filter(date_val=datetime(1970, 1, 1, 12, 36, 59, 123456)), - self.obj, - ) - self.assertEqual(await self.get_by_data_filter(date_val__year=1970), self.obj) - self.assertEqual(await self.get_by_data_filter(date_val__month=1), self.obj) - self.assertEqual(await self.get_by_data_filter(date_val__day=1), self.obj) - self.assertEqual(await self.get_by_data_filter(date_val__hour=12), self.obj) - self.assertEqual(await self.get_by_data_filter(date_val__minute=36), self.obj) - self.assertEqual( - await self.get_by_data_filter(date_val__second=Decimal("59.123456")), self.obj - ) - self.assertEqual(await self.get_by_data_filter(date_val__microsecond=59123456), self.obj) - - async def test_json_list(self): - self.assertEqual(await self.get_by_data_filter(int_list__0__gt=0), self.obj) - self.assertEqual(await self.get_by_data_filter(int_list__0__lt=2), self.obj) - - with self.assertRaises(DoesNotExist): - await self.get_by_data_filter(int_list__0__range=(20, 30)) - - async def test_nested(self): - self.assertEqual(await self.get_by_data_filter(nested__val="word2"), self.obj) - self.assertEqual(await self.get_by_data_filter(nested__int_val=456), self.obj) - self.assertEqual( - await self.get_by_data_filter( - nested__date_val=datetime(1970, 1, 1, 12, 36, 59, 123456) - ), - self.obj, - ) - self.assertEqual(await self.get_by_data_filter(nested__val__icontains="orD"), self.obj) - self.assertEqual(await self.get_by_data_filter(nested__int_val__gte=400), 
self.obj) - self.assertEqual(await self.get_by_data_filter(nested__date_val__year=1970), self.obj) - self.assertEqual(await self.get_by_data_filter(nested__date_val__month=1), self.obj) - self.assertEqual(await self.get_by_data_filter(nested__date_val__day=1), self.obj) - self.assertEqual(await self.get_by_data_filter(nested__date_val__hour=12), self.obj) - self.assertEqual(await self.get_by_data_filter(nested__date_val__minute=36), self.obj) - self.assertEqual( - await self.get_by_data_filter(nested__date_val__second=Decimal("59.123456")), self.obj - ) - self.assertEqual( - await self.get_by_data_filter(nested__date_val__microsecond=59123456), self.obj - ) - self.assertEqual(await self.get_by_data_filter(nested__val__iexact="wOrD2"), self.obj) - self.assertEqual(await self.get_by_data_filter(nested__int_val__lt=500), self.obj) - self.assertEqual(await self.get_by_data_filter(nested__date_val__year=1970), self.obj) - self.assertEqual(await self.get_by_data_filter(nested__date_val__month=1), self.obj) - self.assertEqual(await self.get_by_data_filter(nested__date_val__day=1), self.obj) - self.assertEqual(await self.get_by_data_filter(nested__date_val__hour=12), self.obj) - self.assertEqual(await self.get_by_data_filter(nested__date_val__minute=36), self.obj) - self.assertEqual( - await self.get_by_data_filter(nested__date_val__second=Decimal("59.123456")), self.obj - ) - self.assertEqual( - await self.get_by_data_filter(nested__date_val__microsecond=59123456), self.obj + +@test.requireCapability(dialect="postgres") +@pytest.mark.asyncio +async def test_json_in(json_obj): + assert await get_by_data_filter(json_obj, val__in=["word1", "word2"]) == json_obj + assert await get_by_data_filter(json_obj, val__not_in=["word3", "word4"]) == json_obj + + with pytest.raises(DoesNotExist): + await get_by_data_filter(json_obj, val__in=["doesnotexist"]) + + +@test.requireCapability(dialect="postgres") +@pytest.mark.asyncio +async def test_json_defaults(json_obj): + assert await 
get_by_data_filter(json_obj, val__not="word2") == json_obj + assert await get_by_data_filter(json_obj, val__isnull=False) == json_obj + assert await get_by_data_filter(json_obj, val__not_isnull=True) == json_obj + + +@test.requireCapability(dialect="postgres") +@pytest.mark.asyncio +async def test_json_int_comparisons(json_obj): + assert await get_by_data_filter(json_obj, int_val=123) == json_obj + assert await get_by_data_filter(json_obj, int_val__gt=100) == json_obj + assert await get_by_data_filter(json_obj, int_val__gte=100) == json_obj + assert await get_by_data_filter(json_obj, int_val__lt=200) == json_obj + assert await get_by_data_filter(json_obj, int_val__lte=200) == json_obj + assert await get_by_data_filter(json_obj, int_val__range=[100, 200]) == json_obj + + with pytest.raises(DoesNotExist): + await get_by_data_filter(json_obj, int_val__gt=1000) + + +@test.requireCapability(dialect="postgres") +@pytest.mark.asyncio +async def test_json_float_comparisons(json_obj): + assert await get_by_data_filter(json_obj, float_val__gt=100.0) == json_obj + assert await get_by_data_filter(json_obj, float_val__gte=100.0) == json_obj + assert await get_by_data_filter(json_obj, float_val__lt=200.0) == json_obj + assert await get_by_data_filter(json_obj, float_val__lte=200.0) == json_obj + assert await get_by_data_filter(json_obj, float_val__range=[100.0, 200.0]) == json_obj + + with pytest.raises(DoesNotExist): + await get_by_data_filter(json_obj, int_val__gt=1000.0) + + +@test.requireCapability(dialect="postgres") +@pytest.mark.asyncio +async def test_json_string_comparisons(json_obj): + assert await get_by_data_filter(json_obj, val__contains="ord") == json_obj + assert await get_by_data_filter(json_obj, val__icontains="OrD") == json_obj + assert await get_by_data_filter(json_obj, val__startswith="wor") == json_obj + assert await get_by_data_filter(json_obj, val__istartswith="wOr") == json_obj + assert await get_by_data_filter(json_obj, val__endswith="rd1") == json_obj + 
assert await get_by_data_filter(json_obj, val__iendswith="Rd1") == json_obj + assert await get_by_data_filter(json_obj, val__iexact="wOrD1") == json_obj + + with pytest.raises(DoesNotExist): + await get_by_data_filter(json_obj, val__contains="doesnotexist") + + +@test.requireCapability(dialect="postgres") +@pytest.mark.asyncio +async def test_date_comparisons(json_obj): + assert ( + await get_by_data_filter(json_obj, date_val=datetime(1970, 1, 1, 12, 36, 59, 123456)) + == json_obj + ) + assert await get_by_data_filter(json_obj, date_val__year=1970) == json_obj + assert await get_by_data_filter(json_obj, date_val__month=1) == json_obj + assert await get_by_data_filter(json_obj, date_val__day=1) == json_obj + assert await get_by_data_filter(json_obj, date_val__hour=12) == json_obj + assert await get_by_data_filter(json_obj, date_val__minute=36) == json_obj + assert await get_by_data_filter(json_obj, date_val__second=Decimal("59.123456")) == json_obj + assert await get_by_data_filter(json_obj, date_val__microsecond=59123456) == json_obj + + +@test.requireCapability(dialect="postgres") +@pytest.mark.asyncio +async def test_json_list(json_obj): + assert await get_by_data_filter(json_obj, int_list__0__gt=0) == json_obj + assert await get_by_data_filter(json_obj, int_list__0__lt=2) == json_obj + + with pytest.raises(DoesNotExist): + await get_by_data_filter(json_obj, int_list__0__range=(20, 30)) + + +@test.requireCapability(dialect="postgres") +@pytest.mark.asyncio +async def test_nested(json_obj): + assert await get_by_data_filter(json_obj, nested__val="word2") == json_obj + assert await get_by_data_filter(json_obj, nested__int_val=456) == json_obj + assert ( + await get_by_data_filter( + json_obj, nested__date_val=datetime(1970, 1, 1, 12, 36, 59, 123456) ) - self.assertEqual(await self.get_by_data_filter(nested__val__iexact="wOrD2"), self.obj) + == json_obj + ) + assert await get_by_data_filter(json_obj, nested__val__icontains="orD") == json_obj + assert await 
get_by_data_filter(json_obj, nested__int_val__gte=400) == json_obj + assert await get_by_data_filter(json_obj, nested__date_val__year=1970) == json_obj + assert await get_by_data_filter(json_obj, nested__date_val__month=1) == json_obj + assert await get_by_data_filter(json_obj, nested__date_val__day=1) == json_obj + assert await get_by_data_filter(json_obj, nested__date_val__hour=12) == json_obj + assert await get_by_data_filter(json_obj, nested__date_val__minute=36) == json_obj + assert ( + await get_by_data_filter(json_obj, nested__date_val__second=Decimal("59.123456")) + == json_obj + ) + assert await get_by_data_filter(json_obj, nested__date_val__microsecond=59123456) == json_obj + assert await get_by_data_filter(json_obj, nested__val__iexact="wOrD2") == json_obj + assert await get_by_data_filter(json_obj, nested__int_val__lt=500) == json_obj + assert await get_by_data_filter(json_obj, nested__date_val__year=1970) == json_obj + assert await get_by_data_filter(json_obj, nested__date_val__month=1) == json_obj + assert await get_by_data_filter(json_obj, nested__date_val__day=1) == json_obj + assert await get_by_data_filter(json_obj, nested__date_val__hour=12) == json_obj + assert await get_by_data_filter(json_obj, nested__date_val__minute=36) == json_obj + assert ( + await get_by_data_filter(json_obj, nested__date_val__second=Decimal("59.123456")) + == json_obj + ) + assert await get_by_data_filter(json_obj, nested__date_val__microsecond=59123456) == json_obj + assert await get_by_data_filter(json_obj, nested__val__iexact="wOrD2") == json_obj - async def test_nested_nested(self): - self.assertEqual(await self.get_by_data_filter(nested__nested__val="word3"), self.obj) + +@test.requireCapability(dialect="postgres") +@pytest.mark.asyncio +async def test_nested_nested(json_obj): + assert await get_by_data_filter(json_obj, nested__nested__val="word3") == json_obj diff --git a/tests/contrib/postgres/test_search.py b/tests/contrib/postgres/test_search.py index 
8fee8caf5..25ef289a7 100644 --- a/tests/contrib/postgres/test_search.py +++ b/tests/contrib/postgres/test_search.py @@ -1,10 +1,11 @@ import os +import pytest + from tests.contrib.postgres.models_tsvector import TSVectorEntry from tests.testmodels import TextFields from tortoise import connections from tortoise.backends.psycopg.client import PsycopgClient -from tortoise.contrib import test from tortoise.contrib.postgres.search import ( Lexeme, SearchHeadline, @@ -12,171 +13,225 @@ SearchRank, SearchVector, ) - - -class TestPostgresSearchExpressions(test.TestCase): - async def asyncSetUp(self): - await super().asyncSetUp() - self.db = connections.get("models") - self.is_psycopg = isinstance(self.db, PsycopgClient) - - def assertSql(self, sql: str, expected_psycopg: str, expected_asyncpg: str) -> None: - expected = expected_psycopg if self.is_psycopg else expected_asyncpg - self.assertEqual(sql, expected) - - @test.requireCapability(dialect="postgres") - def test_search_vector(self): - sql = TextFields.all().annotate(search=SearchVector("text")).values("search").sql() - self.assertSql( - sql, - 'SELECT TO_TSVECTOR("text") "search" FROM "textfields"', - 'SELECT TO_TSVECTOR("text") "search" FROM "textfields"', - ) - - @test.requireCapability(dialect="postgres") - def test_search_vector_config_weight(self): - sql = ( - TextFields.all() - .annotate(search=SearchVector("text", config="english", weight="A")) - .values("search") - .sql() - ) - self.assertSql( - sql, - 'SELECT SETWEIGHT(TO_TSVECTOR(%s,"text"),%s) "search" FROM "textfields"', - 'SELECT SETWEIGHT(TO_TSVECTOR($1,"text"),$2) "search" FROM "textfields"', - ) - - @test.requireCapability(dialect="postgres") - def test_search_query_types(self): - sql = ( - TextFields.all() - .annotate(query=SearchQuery("fat", search_type="phrase")) - .values("query") - .sql() - ) - self.assertSql( - sql, - 'SELECT PHRASETO_TSQUERY(%s) "query" FROM "textfields"', - 'SELECT PHRASETO_TSQUERY($1) "query" FROM "textfields"', - ) - - 
@test.requireCapability(dialect="postgres") - def test_search_query_combine_and_invert(self): - query = SearchQuery("fat") & SearchQuery("rat") - sql = TextFields.all().annotate(query=query).values("query").sql() - self.assertSql( - sql, - 'SELECT (PLAINTO_TSQUERY(%s) && PLAINTO_TSQUERY(%s)) "query" FROM "textfields"', - 'SELECT (PLAINTO_TSQUERY($1) && PLAINTO_TSQUERY($2)) "query" FROM "textfields"', - ) - - sql = TextFields.all().annotate(query=SearchQuery("fat", invert=True)).values("query").sql() - self.assertSql( - sql, - 'SELECT !!(PLAINTO_TSQUERY(%s)) "query" FROM "textfields"', - 'SELECT !!(PLAINTO_TSQUERY($1)) "query" FROM "textfields"', - ) - - @test.requireCapability(dialect="postgres") - def test_search_query_lexeme(self): - lexeme_query = SearchQuery(Lexeme("fat") & Lexeme("rat")) - sql = TextFields.all().annotate(query=lexeme_query).values("query").sql() - self.assertSql( - sql, - 'SELECT TO_TSQUERY(%s) "query" FROM "textfields"', - 'SELECT TO_TSQUERY($1) "query" FROM "textfields"', - ) - - @test.requireCapability(dialect="postgres") - def test_search_rank(self): - sql = ( - TextFields.all() - .annotate(rank=SearchRank(SearchVector("text"), SearchQuery("fat"))) - .values("rank") - .sql() - ) - self.assertSql( - sql, - 'SELECT TS_RANK(TO_TSVECTOR("text"),PLAINTO_TSQUERY(%s)) "rank" FROM "textfields"', - 'SELECT TS_RANK(TO_TSVECTOR("text"),PLAINTO_TSQUERY($1)) "rank" FROM "textfields"', - ) - - @test.requireCapability(dialect="postgres") - def test_search_headline(self): - sql = ( - TextFields.all() - .annotate( - headline=SearchHeadline( - "text", - SearchQuery("fat"), - start_sel="", - stop_sel="", - ) +from tortoise.contrib.test import requireCapability + + +def skip_if_not_postgres(): + """Skip test if not running against PostgreSQL.""" + db_url = os.getenv("TORTOISE_TEST_DB", "") + if db_url.split(":", 1)[0] not in {"postgres", "asyncpg", "psycopg"}: + pytest.skip("Postgres-only test.") + + +def assert_sql(db, sql: str, expected_psycopg: str, 
expected_asyncpg: str) -> None: + """Assert SQL matches expected value based on db client type.""" + is_psycopg = isinstance(db, PsycopgClient) + expected = expected_psycopg if is_psycopg else expected_asyncpg + assert sql == expected + + +# ============================================================================= +# TestPostgresSearchExpressions - uses standard testmodels (test.TestCase equivalent) +# ============================================================================= + + +@requireCapability(dialect="postgres") +@pytest.mark.asyncio +async def test_search_vector(db_postgres): + """Test SearchVector expression.""" + db = connections.get("models") + sql = TextFields.all().annotate(search=SearchVector("text")).values("search").sql() + assert_sql( + db, + sql, + 'SELECT TO_TSVECTOR("text") "search" FROM "textfields"', + 'SELECT TO_TSVECTOR("text") "search" FROM "textfields"', + ) + + +@requireCapability(dialect="postgres") +@pytest.mark.asyncio +async def test_search_vector_config_weight(db_postgres): + """Test SearchVector with config and weight.""" + db = connections.get("models") + sql = ( + TextFields.all() + .annotate(search=SearchVector("text", config="english", weight="A")) + .values("search") + .sql() + ) + assert_sql( + db, + sql, + 'SELECT SETWEIGHT(TO_TSVECTOR(%s,"text"),%s) "search" FROM "textfields"', + 'SELECT SETWEIGHT(TO_TSVECTOR($1,"text"),$2) "search" FROM "textfields"', + ) + + +@requireCapability(dialect="postgres") +@pytest.mark.asyncio +async def test_search_query_types(db_postgres): + """Test SearchQuery with different search types.""" + db = connections.get("models") + sql = ( + TextFields.all() + .annotate(query=SearchQuery("fat", search_type="phrase")) + .values("query") + .sql() + ) + assert_sql( + db, + sql, + 'SELECT PHRASETO_TSQUERY(%s) "query" FROM "textfields"', + 'SELECT PHRASETO_TSQUERY($1) "query" FROM "textfields"', + ) + + +@requireCapability(dialect="postgres") +@pytest.mark.asyncio +async def 
test_search_query_combine_and_invert(db_postgres): + """Test SearchQuery combine and invert operations.""" + db = connections.get("models") + query = SearchQuery("fat") & SearchQuery("rat") + sql = TextFields.all().annotate(query=query).values("query").sql() + assert_sql( + db, + sql, + 'SELECT (PLAINTO_TSQUERY(%s) && PLAINTO_TSQUERY(%s)) "query" FROM "textfields"', + 'SELECT (PLAINTO_TSQUERY($1) && PLAINTO_TSQUERY($2)) "query" FROM "textfields"', + ) + + sql = TextFields.all().annotate(query=SearchQuery("fat", invert=True)).values("query").sql() + assert_sql( + db, + sql, + 'SELECT !!(PLAINTO_TSQUERY(%s)) "query" FROM "textfields"', + 'SELECT !!(PLAINTO_TSQUERY($1)) "query" FROM "textfields"', + ) + + +@requireCapability(dialect="postgres") +@pytest.mark.asyncio +async def test_search_query_lexeme(db_postgres): + """Test SearchQuery with Lexeme.""" + db = connections.get("models") + lexeme_query = SearchQuery(Lexeme("fat") & Lexeme("rat")) + sql = TextFields.all().annotate(query=lexeme_query).values("query").sql() + assert_sql( + db, + sql, + 'SELECT TO_TSQUERY(%s) "query" FROM "textfields"', + 'SELECT TO_TSQUERY($1) "query" FROM "textfields"', + ) + + +@requireCapability(dialect="postgres") +@pytest.mark.asyncio +async def test_search_rank(db_postgres): + """Test SearchRank expression.""" + db = connections.get("models") + sql = ( + TextFields.all() + .annotate(rank=SearchRank(SearchVector("text"), SearchQuery("fat"))) + .values("rank") + .sql() + ) + assert_sql( + db, + sql, + 'SELECT TS_RANK(TO_TSVECTOR("text"),PLAINTO_TSQUERY(%s)) "rank" FROM "textfields"', + 'SELECT TS_RANK(TO_TSVECTOR("text"),PLAINTO_TSQUERY($1)) "rank" FROM "textfields"', + ) + + +@requireCapability(dialect="postgres") +@pytest.mark.asyncio +async def test_search_headline(db_postgres): + """Test SearchHeadline expression.""" + db = connections.get("models") + sql = ( + TextFields.all() + .annotate( + headline=SearchHeadline( + "text", + SearchQuery("fat"), + start_sel="", + stop_sel="", ) - 
.values("headline") - .sql() - ) - self.assertSql( - sql, - 'SELECT TS_HEADLINE("text",PLAINTO_TSQUERY(%s),%s) "headline" FROM "textfields"', - 'SELECT TS_HEADLINE("text",PLAINTO_TSQUERY($1),$2) "headline" FROM "textfields"', - ) - - @test.requireCapability(dialect="postgres") - def test_search_lookup_text(self): - sql = TextFields.filter(text__search="fat").values("id").sql() - self.assertSql( - sql, - 'SELECT "id" "id" FROM "textfields" WHERE TO_TSVECTOR("text") @@ PLAINTO_TSQUERY(%s)', - 'SELECT "id" "id" FROM "textfields" WHERE TO_TSVECTOR("text") @@ PLAINTO_TSQUERY($1)', - ) - - @test.requireCapability(dialect="postgres") - def test_search_lookup_text_searchquery(self): - sql = ( - TextFields.filter(text__search=SearchQuery("fat", search_type="raw")).values("id").sql() - ) - self.assertSql( - sql, - 'SELECT "id" "id" FROM "textfields" WHERE TO_TSVECTOR("text") @@ TO_TSQUERY(%s)', - 'SELECT "id" "id" FROM "textfields" WHERE TO_TSVECTOR("text") @@ TO_TSQUERY($1)', - ) - - -class TestPostgresSearchLookupTSVector(test.IsolatedTestCase): - tortoise_test_modules = ["tests.contrib.postgres.models_tsvector"] - - async def asyncSetUp(self): - db_url = os.getenv("TORTOISE_TEST_DB", "") - if db_url.split(":", 1)[0] not in {"postgres", "asyncpg", "psycopg"}: - raise test.SkipTest("Postgres-only test.") - await super().asyncSetUp() - self.db = connections.get("models") - self.is_psycopg = isinstance(self.db, PsycopgClient) - - def assertSql(self, sql: str, expected_psycopg: str, expected_asyncpg: str) -> None: - expected = expected_psycopg if self.is_psycopg else expected_asyncpg - self.assertEqual(sql, expected) - - @test.requireCapability(dialect="postgres") - def test_search_lookup_tsvector(self): - sql = TSVectorEntry.filter(search_vector__search="fat").values("id").sql() - self.assertSql( - sql, - 'SELECT "id" "id" FROM "tsvector_entry" WHERE "search_vector" @@ PLAINTO_TSQUERY(%s)', - 'SELECT "id" "id" FROM "tsvector_entry" WHERE "search_vector" @@ 
PLAINTO_TSQUERY($1)', - ) - - @test.requireCapability(dialect="postgres") - def test_search_lookup_tsvector_searchquery(self): - sql = ( - TSVectorEntry.filter(search_vector__search=SearchQuery("fat", search_type="raw")) - .values("id") - .sql() - ) - self.assertSql( - sql, - 'SELECT "id" "id" FROM "tsvector_entry" WHERE "search_vector" @@ TO_TSQUERY(%s)', - 'SELECT "id" "id" FROM "tsvector_entry" WHERE "search_vector" @@ TO_TSQUERY($1)', ) + .values("headline") + .sql() + ) + assert_sql( + db, + sql, + 'SELECT TS_HEADLINE("text",PLAINTO_TSQUERY(%s),%s) "headline" FROM "textfields"', + 'SELECT TS_HEADLINE("text",PLAINTO_TSQUERY($1),$2) "headline" FROM "textfields"', + ) + + +@requireCapability(dialect="postgres") +@pytest.mark.asyncio +async def test_search_lookup_text(db_postgres): + """Test search lookup on text field.""" + db = connections.get("models") + sql = TextFields.filter(text__search="fat").values("id").sql() + assert_sql( + db, + sql, + 'SELECT "id" "id" FROM "textfields" WHERE TO_TSVECTOR("text") @@ PLAINTO_TSQUERY(%s)', + 'SELECT "id" "id" FROM "textfields" WHERE TO_TSVECTOR("text") @@ PLAINTO_TSQUERY($1)', + ) + + +@requireCapability(dialect="postgres") +@pytest.mark.asyncio +async def test_search_lookup_text_searchquery(db_postgres): + """Test search lookup with SearchQuery on text field.""" + db = connections.get("models") + sql = TextFields.filter(text__search=SearchQuery("fat", search_type="raw")).values("id").sql() + assert_sql( + db, + sql, + 'SELECT "id" "id" FROM "textfields" WHERE TO_TSVECTOR("text") @@ TO_TSQUERY(%s)', + 'SELECT "id" "id" FROM "textfields" WHERE TO_TSVECTOR("text") @@ TO_TSQUERY($1)', + ) + + +# ============================================================================= +# TestPostgresSearchLookupTSVector - uses TSVector models (IsolatedTestCase equivalent) +# ============================================================================= + + +@requireCapability(dialect="postgres") +@pytest.mark.asyncio +async def 
test_search_lookup_tsvector(db_search): + """Test search lookup on TSVector field.""" + skip_if_not_postgres() + db = connections.get("models") + sql = TSVectorEntry.filter(search_vector__search="fat").values("id").sql() + assert_sql( + db, + sql, + 'SELECT "id" "id" FROM "tsvector_entry" WHERE "search_vector" @@ PLAINTO_TSQUERY(%s)', + 'SELECT "id" "id" FROM "tsvector_entry" WHERE "search_vector" @@ PLAINTO_TSQUERY($1)', + ) + + +@requireCapability(dialect="postgres") +@pytest.mark.asyncio +async def test_search_lookup_tsvector_searchquery(db_search): + """Test search lookup with SearchQuery on TSVector field.""" + skip_if_not_postgres() + db = connections.get("models") + sql = ( + TSVectorEntry.filter(search_vector__search=SearchQuery("fat", search_type="raw")) + .values("id") + .sql() + ) + assert_sql( + db, + sql, + 'SELECT "id" "id" FROM "tsvector_entry" WHERE "search_vector" @@ TO_TSQUERY(%s)', + 'SELECT "id" "id" FROM "tsvector_entry" WHERE "search_vector" @@ TO_TSQUERY($1)', + ) diff --git a/tests/contrib/postgres/test_tsvector_field.py b/tests/contrib/postgres/test_tsvector_field.py index 2c7721b13..531b6e9e1 100644 --- a/tests/contrib/postgres/test_tsvector_field.py +++ b/tests/contrib/postgres/test_tsvector_field.py @@ -1,23 +1,24 @@ import os +import pytest + from tests.contrib.postgres.models_tsvector import TSVectorEntry -from tortoise.contrib import test -class TestTSVectorField(test.IsolatedTestCase): - tortoise_test_modules = ["tests.contrib.postgres.models_tsvector"] +def skip_if_not_postgres(): + """Skip test if not running against PostgreSQL.""" + db_url = os.getenv("TORTOISE_TEST_DB", "") + if db_url.split(":", 1)[0] not in {"postgres", "asyncpg", "psycopg"}: + pytest.skip("Postgres-only test.") - async def asyncSetUp(self) -> None: - db_url = os.getenv("TORTOISE_TEST_DB", "") - if db_url.split(":", 1)[0] not in {"postgres", "asyncpg", "psycopg"}: - raise test.SkipTest("Postgres-only test.") - await super().asyncSetUp() - def 
test_tsvector_generated_sql(self) -> None: - field = TSVectorEntry._meta.fields_map["search_vector"] - sql = field.get_for_dialect("postgres", "GENERATED_SQL") - self.assertEqual( - sql, - "GENERATED ALWAYS AS (SETWEIGHT(TO_TSVECTOR('english',COALESCE(\"title\", '')),'A')" - " || SETWEIGHT(TO_TSVECTOR('english',COALESCE(\"body\", '')),'B')) STORED", - ) +@pytest.mark.asyncio +async def test_tsvector_generated_sql(db_tsvector): + """Test TSVector field generates correct SQL.""" + skip_if_not_postgres() + field = TSVectorEntry._meta.fields_map["search_vector"] + sql = field.get_for_dialect("postgres", "GENERATED_SQL") + assert sql == ( + "GENERATED ALWAYS AS (SETWEIGHT(TO_TSVECTOR('english',COALESCE(\"title\", '')),'A')" + " || SETWEIGHT(TO_TSVECTOR('english',COALESCE(\"body\", '')),'B')) STORED" + ) diff --git a/tests/contrib/test_decorator.py b/tests/contrib/test_decorator.py index e05907a53..e8f03e51e 100644 --- a/tests/contrib/test_decorator.py +++ b/tests/contrib/test_decorator.py @@ -1,75 +1,93 @@ +import os import subprocess # nosec +import sys from unittest.mock import AsyncMock, patch -from tortoise.contrib import test -from tortoise.contrib.test import init_memory_sqlite - - -class TestDecorator(test.TestCase): - @test.requireCapability(dialect="sqlite") - async def test_script_with_init_memory_sqlite(self) -> None: - r = subprocess.run(["python", "examples/basic.py"], capture_output=True, text=True) # nosec - assert not r.stderr - output = r.stdout - s = "[{'id': 1, 'name': 'Updated name'}, {'id': 2, 'name': 'Test 2'}]" - self.assertIn(s, output) - - @test.requireCapability(dialect="sqlite") - @patch("tortoise.Tortoise.init") - @patch("tortoise.Tortoise.generate_schemas") - async def test_init_memory_sqlite( - self, - mocked_generate: AsyncMock, - mocked_init: AsyncMock, - ) -> None: - @init_memory_sqlite - async def run(): - return "foo" - - res = await run() - self.assertEqual(res, "foo") - mocked_init.assert_awaited_once() - 
mocked_init.assert_called_once_with( - db_url="sqlite://:memory:", modules={"models": ["__main__"]} - ) - mocked_generate.assert_awaited_once() - - @test.requireCapability(dialect="sqlite") - @patch("tortoise.Tortoise.init") - @patch("tortoise.Tortoise.generate_schemas") - async def test_init_memory_sqlite_with_models( - self, - mocked_generate: AsyncMock, - mocked_init: AsyncMock, - ) -> None: - @init_memory_sqlite(["app.models"]) - async def run(): - return "foo" - - res = await run() - self.assertEqual(res, "foo") - mocked_init.assert_awaited_once() - mocked_init.assert_called_once_with( - db_url="sqlite://:memory:", modules={"models": ["app.models"]} - ) - mocked_generate.assert_awaited_once() - - @test.requireCapability(dialect="sqlite") - @patch("tortoise.Tortoise.init") - @patch("tortoise.Tortoise.generate_schemas") - async def test_init_memory_sqlite_model_str( - self, - mocked_generate: AsyncMock, - mocked_init: AsyncMock, - ) -> None: - @init_memory_sqlite("app.models") - async def run(): - return "foo" - - res = await run() - self.assertEqual(res, "foo") - mocked_init.assert_awaited_once() - mocked_init.assert_called_once_with( - db_url="sqlite://:memory:", modules={"models": ["app.models"]} - ) - mocked_generate.assert_awaited_once() +import pytest + +from tortoise.contrib.test import init_memory_sqlite, requireCapability + + +@pytest.mark.asyncio +@requireCapability(dialect="sqlite") +async def test_basic_example_script(db) -> None: + """Test that the basic example script runs successfully.""" + # Set PYTHONPATH to use local source instead of installed package + env = os.environ.copy() + env["PYTHONPATH"] = os.getcwd() + r = subprocess.run( # nosec + [sys.executable, "examples/basic.py"], capture_output=True, text=True, env=env + ) + assert not r.stderr, f"Script had errors: {r.stderr}" + output = r.stdout + s = "[{'id': 1, 'name': 'Updated name'}, {'id': 2, 'name': 'Test 2'}]" + assert s in output + + +@pytest.mark.asyncio 
+@requireCapability(dialect="sqlite") +@patch("tortoise.Tortoise.init") +@patch("tortoise.Tortoise.generate_schemas") +async def test_init_memory_sqlite_decorator( + mocked_generate: AsyncMock, + mocked_init: AsyncMock, + db, +) -> None: + """Test init_memory_sqlite as decorator without parentheses.""" + + @init_memory_sqlite + async def run(): + return "result" + + result = await run() + assert result == "result" + mocked_init.assert_awaited_once_with( + db_url="sqlite://:memory:", modules={"models": ["__main__"]} + ) + mocked_generate.assert_awaited_once() + + +@pytest.mark.asyncio +@requireCapability(dialect="sqlite") +@patch("tortoise.Tortoise.init") +@patch("tortoise.Tortoise.generate_schemas") +async def test_init_memory_sqlite_decorator_with_models_list( + mocked_generate: AsyncMock, + mocked_init: AsyncMock, + db, +) -> None: + """Test init_memory_sqlite as decorator with models list.""" + + @init_memory_sqlite(["app.models"]) + async def run(): + return "result" + + result = await run() + assert result == "result" + mocked_init.assert_awaited_once_with( + db_url="sqlite://:memory:", modules={"models": ["app.models"]} + ) + mocked_generate.assert_awaited_once() + + +@pytest.mark.asyncio +@requireCapability(dialect="sqlite") +@patch("tortoise.Tortoise.init") +@patch("tortoise.Tortoise.generate_schemas") +async def test_init_memory_sqlite_decorator_with_models_string( + mocked_generate: AsyncMock, + mocked_init: AsyncMock, + db, +) -> None: + """Test init_memory_sqlite as decorator with models string.""" + + @init_memory_sqlite("app.models") + async def run(): + return "result" + + result = await run() + assert result == "result" + mocked_init.assert_awaited_once_with( + db_url="sqlite://:memory:", modules={"models": ["app.models"]} + ) + mocked_generate.assert_awaited_once() diff --git a/tests/contrib/test_fastapi.py b/tests/contrib/test_fastapi.py index fee280414..fb396fd9e 100644 --- a/tests/contrib/test_fastapi.py +++ b/tests/contrib/test_fastapi.py @@ 
-1,35 +1,37 @@ from unittest.mock import AsyncMock, patch +import pytest from fastapi import FastAPI from tortoise.contrib import test from tortoise.contrib.fastapi import RegisterTortoise -class TestRegisterTortoise(test.TestCase): - @test.requireCapability(dialect="sqlite") - @patch("tortoise.Tortoise.init") - @patch("tortoise.connections.close_all") - async def test_await( - self, - mocked_close: AsyncMock, - mocked_init: AsyncMock, - ) -> None: - app = FastAPI() - orm = await RegisterTortoise( - app, - db_url="sqlite://:memory:", - modules={"models": ["__main__"]}, - ) - mocked_init.assert_awaited_once() - mocked_init.assert_called_once_with( - config=None, - config_file=None, - db_url="sqlite://:memory:", - modules={"models": ["__main__"]}, - use_tz=False, - timezone="UTC", - _create_db=False, - ) - await orm.close_orm() - mocked_close.assert_awaited_once() +@pytest.mark.asyncio +@test.requireCapability(dialect="sqlite") +@patch("tortoise.Tortoise.init") +@patch("tortoise.Tortoise.close_connections") +async def test_await( + mocked_close_connections: AsyncMock, + mocked_init: AsyncMock, + db, +) -> None: + app = FastAPI() + orm = await RegisterTortoise( + app, + db_url="sqlite://:memory:", + modules={"models": ["__main__"]}, + ) + mocked_init.assert_awaited_once() + mocked_init.assert_called_once_with( + config=None, + config_file=None, + db_url="sqlite://:memory:", + modules={"models": ["__main__"]}, + use_tz=False, + timezone="UTC", + _create_db=False, + _enable_global_fallback=True, + ) + await orm.close_orm() + mocked_close_connections.assert_awaited_once() diff --git a/tests/contrib/test_functions.py b/tests/contrib/test_functions.py index 276dd360e..9bf6a538b 100644 --- a/tests/contrib/test_functions.py +++ b/tests/contrib/test_functions.py @@ -1,37 +1,45 @@ +import pytest +import pytest_asyncio + from tests.testmodels import IntFields -from tortoise import connections from tortoise.contrib import test from tortoise.contrib.mysql.functions import Rand 
from tortoise.contrib.postgres.functions import Random as PostgresRandom from tortoise.contrib.sqlite.functions import Random as SqliteRandom -class TestFunction(test.TestCase): - async def asyncSetUp(self): - await super().asyncSetUp() - self.intfields = [await IntFields.create(intnum=val) for val in range(10)] - self.db = connections.get("models") - - @test.requireCapability(dialect="mysql") - async def test_mysql_func_rand(self): - sql = IntFields.all().annotate(randnum=Rand()).values("intnum", "randnum").sql() - expected_sql = "SELECT `intnum` `intnum`,RAND() `randnum` FROM `intfields`" - self.assertEqual(sql, expected_sql) - - @test.requireCapability(dialect="mysql") - async def test_mysql_func_rand_with_seed(self): - sql = IntFields.all().annotate(randnum=Rand(0)).values("intnum", "randnum").sql() - expected_sql = "SELECT `intnum` `intnum`,RAND(%s) `randnum` FROM `intfields`" - self.assertEqual(sql, expected_sql) - - @test.requireCapability(dialect="postgres") - async def test_postgres_func_rand(self): - sql = IntFields.all().annotate(randnum=PostgresRandom()).values("intnum", "randnum").sql() - expected_sql = 'SELECT "intnum" "intnum",RANDOM() "randnum" FROM "intfields"' - self.assertEqual(sql, expected_sql) - - @test.requireCapability(dialect="sqlite") - async def test_sqlite_func_rand(self): - sql = IntFields.all().annotate(randnum=SqliteRandom()).values("intnum", "randnum").sql() - expected_sql = 'SELECT "intnum" "intnum",RANDOM() "randnum" FROM "intfields"' - self.assertEqual(sql, expected_sql) +@pytest_asyncio.fixture +async def intfields(db): + return [await IntFields.create(intnum=val) for val in range(10)] + + +@pytest.mark.asyncio +@test.requireCapability(dialect="mysql") +async def test_mysql_func_rand(db, intfields): + sql = IntFields.all().annotate(randnum=Rand()).values("intnum", "randnum").sql() + expected_sql = "SELECT `intnum` `intnum`,RAND() `randnum` FROM `intfields`" + assert sql == expected_sql + + +@pytest.mark.asyncio 
+@test.requireCapability(dialect="mysql") +async def test_mysql_func_rand_with_seed(db, intfields): + sql = IntFields.all().annotate(randnum=Rand(0)).values("intnum", "randnum").sql() + expected_sql = "SELECT `intnum` `intnum`,RAND(%s) `randnum` FROM `intfields`" + assert sql == expected_sql + + +@pytest.mark.asyncio +@test.requireCapability(dialect="postgres") +async def test_postgres_func_rand(db, intfields): + sql = IntFields.all().annotate(randnum=PostgresRandom()).values("intnum", "randnum").sql() + expected_sql = 'SELECT "intnum" "intnum",RANDOM() "randnum" FROM "intfields"' + assert sql == expected_sql + + +@pytest.mark.asyncio +@test.requireCapability(dialect="sqlite") +async def test_sqlite_func_rand(db, intfields): + sql = IntFields.all().annotate(randnum=SqliteRandom()).values("intnum", "randnum").sql() + expected_sql = 'SELECT "intnum" "intnum",RANDOM() "randnum" FROM "intfields"' + assert sql == expected_sql diff --git a/tests/contrib/test_pydantic.py b/tests/contrib/test_pydantic.py index 56ebeaf37..55f03faed 100644 --- a/tests/contrib/test_pydantic.py +++ b/tests/contrib/test_pydantic.py @@ -1,6 +1,7 @@ import copy import pytest +import pytest_asyncio from pydantic import ConfigDict, ValidationError from tests.testmodels import ( @@ -19,7 +20,6 @@ User, json_pydantic_default, ) -from tortoise.contrib import test from tortoise.contrib.pydantic import ( PydanticModel, pydantic_model_creator, @@ -27,180 +27,280 @@ ) -class TestPydantic(test.TestCase): - async def asyncSetUp(self) -> None: - await super().asyncSetUp() - self.Event_Pydantic = pydantic_model_creator(Event) - self.Event_Pydantic_List = pydantic_queryset_creator(Event) - self.Tournament_Pydantic = pydantic_model_creator(Tournament) - self.Team_Pydantic = pydantic_model_creator(Team) - self.Address_Pydantic = pydantic_model_creator(Address) - self.ModelTestPydanticMetaBackwardRelations1_Pydantic = pydantic_model_creator( - ModelTestPydanticMetaBackwardRelations1 - ) - 
self.ModelTestPydanticMetaBackwardRelations2_Pydantic = pydantic_model_creator( - ModelTestPydanticMetaBackwardRelations2 - ) - - class PydanticMetaOverride: - backward_relations = False - - self.Event_Pydantic_non_backward_from_override = pydantic_model_creator( - Event, meta_override=PydanticMetaOverride, name="Event_non_backward" - ) - - self.tournament = await Tournament.create(name="New Tournament") - self.reporter = await Reporter.create(name="The Reporter") - self.event = await Event.create( - name="Test", tournament=self.tournament, reporter=self.reporter - ) - self.event2 = await Event.create(name="Test2", tournament=self.tournament) - self.address = await Address.create(city="Santa Monica", street="Ocean", event=self.event) - self.team1 = await Team.create(name="Onesies") - self.team2 = await Team.create(name="T-Shirts") - await self.event.participants.add(self.team1, self.team2) - await self.event2.participants.add(self.team1, self.team2) - self.maxDiff = None - - async def test_backward_relations_with_meta_override(self): - event_schema = copy.deepcopy(dict(self.Event_Pydantic.model_json_schema())) - event_non_backward_schema_by_override = copy.deepcopy( - dict(self.Event_Pydantic_non_backward_from_override.model_json_schema()) - ) - self.assertTrue("address" in event_schema["properties"]) - self.assertFalse("address" in event_non_backward_schema_by_override["properties"]) - del event_schema["properties"]["address"] - self.assertEqual( - event_schema["properties"], event_non_backward_schema_by_override["properties"] - ) - - async def test_backward_relations_with_pydantic_meta(self): - test_model1_schema = ( - self.ModelTestPydanticMetaBackwardRelations1_Pydantic.model_json_schema() - ) - test_model2_schema = ( - self.ModelTestPydanticMetaBackwardRelations2_Pydantic.model_json_schema() - ) - self.assertTrue("threes" in test_model2_schema["properties"]) - self.assertFalse("threes" in test_model1_schema["properties"]) - del 
test_model2_schema["properties"]["threes"] - self.assertEqual(test_model2_schema["properties"], test_model1_schema["properties"]) - print(test_model2_schema) - - def test_event_schema(self): - self.assertEqual( - self.Event_Pydantic.model_json_schema(), - { - "$defs": { - "Address_e4rhju_leaf": { - "additionalProperties": False, - "properties": { - "city": {"maxLength": 64, "title": "City", "type": "string"}, - "street": {"maxLength": 128, "title": "Street", "type": "string"}, - "m2mwitho2opks": { - "items": {"$ref": "#/$defs/M2mWithO2oPk_leajz6_leaf"}, - "title": "M2Mwitho2Opks", - "type": "array", - }, - "event_id": { - "maximum": 9223372036854775807, - "minimum": -9223372036854775808, - "title": "Event Id", - "type": "integer", - }, - }, - "required": ["city", "street", "event_id", "m2mwitho2opks"], - "title": "Address", - "type": "object", +# Fixtures for TestPydantic +@pytest_asyncio.fixture +async def pydantic_setup(db): + """Setup for pydantic tests with models and data.""" + Event_Pydantic = pydantic_model_creator(Event) + Event_Pydantic_List = pydantic_queryset_creator(Event) + Tournament_Pydantic = pydantic_model_creator(Tournament) + Team_Pydantic = pydantic_model_creator(Team) + Address_Pydantic = pydantic_model_creator(Address) + ModelTestPydanticMetaBackwardRelations1_Pydantic = pydantic_model_creator( + ModelTestPydanticMetaBackwardRelations1 + ) + ModelTestPydanticMetaBackwardRelations2_Pydantic = pydantic_model_creator( + ModelTestPydanticMetaBackwardRelations2 + ) + + class PydanticMetaOverride: + backward_relations = False + + Event_Pydantic_non_backward_from_override = pydantic_model_creator( + Event, meta_override=PydanticMetaOverride, name="Event_non_backward" + ) + + tournament = await Tournament.create(name="New Tournament") + reporter = await Reporter.create(name="The Reporter") + event = await Event.create(name="Test", tournament=tournament, reporter=reporter) + event2 = await Event.create(name="Test2", tournament=tournament) + address = 
await Address.create(city="Santa Monica", street="Ocean", event=event) + team1 = await Team.create(name="Onesies") + team2 = await Team.create(name="T-Shirts") + await event.participants.add(team1, team2) + await event2.participants.add(team1, team2) + + return { + "Event_Pydantic": Event_Pydantic, + "Event_Pydantic_List": Event_Pydantic_List, + "Tournament_Pydantic": Tournament_Pydantic, + "Team_Pydantic": Team_Pydantic, + "Address_Pydantic": Address_Pydantic, + "ModelTestPydanticMetaBackwardRelations1_Pydantic": ModelTestPydanticMetaBackwardRelations1_Pydantic, + "ModelTestPydanticMetaBackwardRelations2_Pydantic": ModelTestPydanticMetaBackwardRelations2_Pydantic, + "Event_Pydantic_non_backward_from_override": Event_Pydantic_non_backward_from_override, + "tournament": tournament, + "reporter": reporter, + "event": event, + "event2": event2, + "address": address, + "team1": team1, + "team2": team2, + } + + +@pytest.mark.asyncio +async def test_backward_relations_with_meta_override(db, pydantic_setup): + Event_Pydantic = pydantic_setup["Event_Pydantic"] + Event_Pydantic_non_backward_from_override = pydantic_setup[ + "Event_Pydantic_non_backward_from_override" + ] + + event_schema = copy.deepcopy(dict(Event_Pydantic.model_json_schema())) + event_non_backward_schema_by_override = copy.deepcopy( + dict(Event_Pydantic_non_backward_from_override.model_json_schema()) + ) + assert "address" in event_schema["properties"] + assert "address" not in event_non_backward_schema_by_override["properties"] + del event_schema["properties"]["address"] + assert event_schema["properties"] == event_non_backward_schema_by_override["properties"] + + +@pytest.mark.asyncio +async def test_backward_relations_with_pydantic_meta(db, pydantic_setup): + ModelTestPydanticMetaBackwardRelations1_Pydantic = pydantic_setup[ + "ModelTestPydanticMetaBackwardRelations1_Pydantic" + ] + ModelTestPydanticMetaBackwardRelations2_Pydantic = pydantic_setup[ + "ModelTestPydanticMetaBackwardRelations2_Pydantic" + 
] + + test_model1_schema = ModelTestPydanticMetaBackwardRelations1_Pydantic.model_json_schema() + test_model2_schema = ModelTestPydanticMetaBackwardRelations2_Pydantic.model_json_schema() + assert "threes" in test_model2_schema["properties"] + assert "threes" not in test_model1_schema["properties"] + del test_model2_schema["properties"]["threes"] + assert test_model2_schema["properties"] == test_model1_schema["properties"] + print(test_model2_schema) + + +@pytest.mark.asyncio +async def test_event_schema(db, pydantic_setup): + Event_Pydantic = pydantic_setup["Event_Pydantic"] + assert Event_Pydantic.model_json_schema() == { + "$defs": { + "Address_e4rhju_leaf": { + "additionalProperties": False, + "properties": { + "city": {"maxLength": 64, "title": "City", "type": "string"}, + "street": {"maxLength": 128, "title": "Street", "type": "string"}, + "m2mwitho2opks": { + "items": {"$ref": "#/$defs/M2mWithO2oPk_leajz6_leaf"}, + "title": "M2Mwitho2Opks", + "type": "array", }, - "M2mWithO2oPk_leajz6_leaf": { - "additionalProperties": False, - "properties": { - "id": { - "maximum": 2147483647, - "minimum": -2147483648, - "title": "Id", - "type": "integer", - }, - "name": {"maxLength": 64, "title": "Name", "type": "string"}, - }, - "required": ["id", "name"], - "title": "M2mWithO2oPk", - "type": "object", - }, - "Reporter_fgnv33_leaf": { - "additionalProperties": False, - "description": "Whom is assigned as the reporter", - "properties": { - "id": { - "maximum": 2147483647, - "minimum": -2147483648, - "title": "Id", - "type": "integer", - }, - "name": {"title": "Name", "type": "string"}, - }, - "required": ["id", "name"], - "title": "Reporter", - "type": "object", + "event_id": { + "maximum": 9223372036854775807, + "minimum": -9223372036854775808, + "title": "Event Id", + "type": "integer", + }, + }, + "required": ["city", "street", "event_id", "m2mwitho2opks"], + "title": "Address", + "type": "object", + }, + "M2mWithO2oPk_leajz6_leaf": { + "additionalProperties": False, + 
"properties": { + "id": { + "maximum": 2147483647, + "minimum": -2147483648, + "title": "Id", + "type": "integer", + }, + "name": {"maxLength": 64, "title": "Name", "type": "string"}, + }, + "required": ["id", "name"], + "title": "M2mWithO2oPk", + "type": "object", + }, + "Reporter_fgnv33_leaf": { + "additionalProperties": False, + "description": "Whom is assigned as the reporter", + "properties": { + "id": { + "maximum": 2147483647, + "minimum": -2147483648, + "title": "Id", + "type": "integer", + }, + "name": {"title": "Name", "type": "string"}, + }, + "required": ["id", "name"], + "title": "Reporter", + "type": "object", + }, + "Team_ip4pg6_leaf": { + "additionalProperties": False, + "description": "Team that is a playing", + "properties": { + "id": { + "maximum": 2147483647, + "minimum": -2147483648, + "title": "Id", + "type": "integer", }, - "Team_ip4pg6_leaf": { - "additionalProperties": False, - "description": "Team that is a playing", - "properties": { - "id": { + "name": {"title": "Name", "type": "string"}, + "alias": { + "anyOf": [ + { "maximum": 2147483647, "minimum": -2147483648, - "title": "Id", - "type": "integer", - }, - "name": {"title": "Name", "type": "string"}, - "alias": { - "anyOf": [ - { - "maximum": 2147483647, - "minimum": -2147483648, - "type": "integer", - }, - {"type": "null"}, - ], - "default": None, - "nullable": True, - "title": "Alias", - }, - }, - "required": ["id", "name"], - "title": "Team", - "type": "object", - }, - "Tournament_5y7e7j_leaf": { - "additionalProperties": False, - "properties": { - "id": { - "maximum": 32767, - "minimum": -32768, - "title": "Id", "type": "integer", }, - "name": {"maxLength": 255, "title": "Name", "type": "string"}, - "desc": { - "anyOf": [{"type": "string"}, {"type": "null"}], - "default": None, - "nullable": True, - "title": "Desc", - }, - "created": { - "format": "date-time", - "readOnly": True, - "title": "Created", - "type": "string", - }, - }, - "required": ["id", "name", "created"], - "title": 
"Tournament", - "type": "object", + {"type": "null"}, + ], + "default": None, + "nullable": True, + "title": "Alias", + }, + }, + "required": ["id", "name"], + "title": "Team", + "type": "object", + }, + "Tournament_5y7e7j_leaf": { + "additionalProperties": False, + "properties": { + "id": { + "maximum": 32767, + "minimum": -32768, + "title": "Id", + "type": "integer", + }, + "name": {"maxLength": 255, "title": "Name", "type": "string"}, + "desc": { + "anyOf": [{"type": "string"}, {"type": "null"}], + "default": None, + "nullable": True, + "title": "Desc", + }, + "created": { + "format": "date-time", + "readOnly": True, + "title": "Created", + "type": "string", }, }, + "required": ["id", "name", "created"], + "title": "Tournament", + "type": "object", + }, + }, + "additionalProperties": False, + "description": "Events on the calendar", + "properties": { + "event_id": { + "maximum": 9223372036854775807, + "minimum": -9223372036854775808, + "title": "Event Id", + "type": "integer", + }, + "name": {"description": "The name", "title": "Name", "type": "string"}, + "tournament": { + "$ref": "#/$defs/Tournament_5y7e7j_leaf", + "description": "What tournaments is a happenin'", + }, + "reporter": { + "anyOf": [ + {"$ref": "#/$defs/Reporter_fgnv33_leaf"}, + {"type": "null"}, + ], + "nullable": True, + "title": "Reporter", + }, + "participants": { + "items": {"$ref": "#/$defs/Team_ip4pg6_leaf"}, + "title": "Participants", + "type": "array", + }, + "modified": { + "format": "date-time", + "readOnly": True, + "title": "Modified", + "type": "string", + }, + "token": {"anyOf": [{"type": "string"}, {"type": "null"}], "title": "Token"}, + "alias": { + "anyOf": [ + {"maximum": 2147483647, "minimum": -2147483648, "type": "integer"}, + {"type": "null"}, + ], + "default": None, + "nullable": True, + "title": "Alias", + }, + "address": { + "anyOf": [ + {"$ref": "#/$defs/Address_e4rhju_leaf"}, + {"type": "null"}, + ], + "nullable": True, + "title": "Address", + }, + }, + "required": [ + 
"event_id", + "name", + "tournament", + "reporter", + "participants", + "modified", + "token", + "address", + ], + "title": "Event", + "type": "object", + } + + +@pytest.mark.asyncio +async def test_eventlist_schema(db, pydantic_setup): + Event_Pydantic_List = pydantic_setup["Event_Pydantic_List"] + assert Event_Pydantic_List.model_json_schema() == { + "$defs": { + "Event_mfxmwb": { "additionalProperties": False, "description": "Events on the calendar", "properties": { @@ -234,10 +334,17 @@ def test_event_schema(self): "title": "Modified", "type": "string", }, - "token": {"anyOf": [{"type": "string"}, {"type": "null"}], "title": "Token"}, + "token": { + "anyOf": [{"type": "string"}, {"type": "null"}], + "title": "Token", + }, "alias": { "anyOf": [ - {"maximum": 2147483647, "minimum": -2147483648, "type": "integer"}, + { + "maximum": 2147483647, + "minimum": -2147483648, + "type": "integer", + }, {"type": "null"}, ], "default": None, @@ -266,359 +373,7 @@ def test_event_schema(self): "title": "Event", "type": "object", }, - ) - - def test_eventlist_schema(self): - self.assertEqual( - self.Event_Pydantic_List.model_json_schema(), - { - "$defs": { - "Event_mfxmwb": { - "additionalProperties": False, - "description": "Events on the calendar", - "properties": { - "event_id": { - "maximum": 9223372036854775807, - "minimum": -9223372036854775808, - "title": "Event Id", - "type": "integer", - }, - "name": {"description": "The name", "title": "Name", "type": "string"}, - "tournament": { - "$ref": "#/$defs/Tournament_5y7e7j_leaf", - "description": "What tournaments is a happenin'", - }, - "reporter": { - "anyOf": [ - {"$ref": "#/$defs/Reporter_fgnv33_leaf"}, - {"type": "null"}, - ], - "nullable": True, - "title": "Reporter", - }, - "participants": { - "items": {"$ref": "#/$defs/Team_ip4pg6_leaf"}, - "title": "Participants", - "type": "array", - }, - "modified": { - "format": "date-time", - "readOnly": True, - "title": "Modified", - "type": "string", - }, - "token": { - 
"anyOf": [{"type": "string"}, {"type": "null"}], - "title": "Token", - }, - "alias": { - "anyOf": [ - { - "maximum": 2147483647, - "minimum": -2147483648, - "type": "integer", - }, - {"type": "null"}, - ], - "default": None, - "nullable": True, - "title": "Alias", - }, - "address": { - "anyOf": [ - {"$ref": "#/$defs/Address_e4rhju_leaf"}, - {"type": "null"}, - ], - "nullable": True, - "title": "Address", - }, - }, - "required": [ - "event_id", - "name", - "tournament", - "reporter", - "participants", - "modified", - "token", - "address", - ], - "title": "Event", - "type": "object", - }, - "Address_e4rhju_leaf": { - "additionalProperties": False, - "properties": { - "city": {"maxLength": 64, "title": "City", "type": "string"}, - "street": {"maxLength": 128, "title": "Street", "type": "string"}, - "m2mwitho2opks": { - "items": {"$ref": "#/$defs/M2mWithO2oPk_leajz6_leaf"}, - "title": "M2Mwitho2Opks", - "type": "array", - }, - "event_id": { - "maximum": 9223372036854775807, - "minimum": -9223372036854775808, - "title": "Event Id", - "type": "integer", - }, - }, - "required": ["city", "street", "event_id", "m2mwitho2opks"], - "title": "Address", - "type": "object", - }, - "M2mWithO2oPk_leajz6_leaf": { - "additionalProperties": False, - "properties": { - "id": { - "maximum": 2147483647, - "minimum": -2147483648, - "title": "Id", - "type": "integer", - }, - "name": {"maxLength": 64, "title": "Name", "type": "string"}, - }, - "required": ["id", "name"], - "title": "M2mWithO2oPk", - "type": "object", - }, - "Reporter_fgnv33_leaf": { - "additionalProperties": False, - "description": "Whom is assigned as the reporter", - "properties": { - "id": { - "maximum": 2147483647, - "minimum": -2147483648, - "title": "Id", - "type": "integer", - }, - "name": {"title": "Name", "type": "string"}, - }, - "required": ["id", "name"], - "title": "Reporter", - "type": "object", - }, - "Team_ip4pg6_leaf": { - "additionalProperties": False, - "description": "Team that is a playing", - 
"properties": { - "id": { - "maximum": 2147483647, - "minimum": -2147483648, - "title": "Id", - "type": "integer", - }, - "name": {"title": "Name", "type": "string"}, - "alias": { - "anyOf": [ - { - "maximum": 2147483647, - "minimum": -2147483648, - "type": "integer", - }, - {"type": "null"}, - ], - "default": None, - "nullable": True, - "title": "Alias", - }, - }, - "required": ["id", "name"], - "title": "Team", - "type": "object", - }, - "Tournament_5y7e7j_leaf": { - "additionalProperties": False, - "properties": { - "id": { - "maximum": 32767, - "minimum": -32768, - "title": "Id", - "type": "integer", - }, - "name": {"maxLength": 255, "title": "Name", "type": "string"}, - "desc": { - "anyOf": [{"type": "string"}, {"type": "null"}], - "default": None, - "nullable": True, - "title": "Desc", - }, - "created": { - "format": "date-time", - "readOnly": True, - "title": "Created", - "type": "string", - }, - }, - "required": ["id", "name", "created"], - "title": "Tournament", - "type": "object", - }, - }, - "description": "Events on the calendar", - "items": {"$ref": "#/$defs/Event_mfxmwb"}, - "title": "Event_list", - "type": "array", - }, - ) - - def test_address_schema(self): - self.assertEqual( - self.Address_Pydantic.model_json_schema(), - { - "$defs": { - "Event_zvunzw_leaf": { - "additionalProperties": False, - "description": "Events on the calendar", - "properties": { - "event_id": { - "maximum": 9223372036854775807, - "minimum": -9223372036854775808, - "title": "Event Id", - "type": "integer", - }, - "name": {"description": "The name", "title": "Name", "type": "string"}, - "tournament": { - "$ref": "#/$defs/Tournament_5y7e7j_leaf", - "description": "What tournaments is a happenin'", - }, - "reporter": { - "anyOf": [ - {"$ref": "#/$defs/Reporter_fgnv33_leaf"}, - {"type": "null"}, - ], - "nullable": True, - "title": "Reporter", - }, - "participants": { - "items": {"$ref": "#/$defs/Team_ip4pg6_leaf"}, - "title": "Participants", - "type": "array", - }, - "modified": 
{ - "format": "date-time", - "readOnly": True, - "title": "Modified", - "type": "string", - }, - "token": { - "anyOf": [{"type": "string"}, {"type": "null"}], - "title": "Token", - }, - "alias": { - "anyOf": [ - { - "maximum": 2147483647, - "minimum": -2147483648, - "type": "integer", - }, - {"type": "null"}, - ], - "default": None, - "nullable": True, - "title": "Alias", - }, - }, - "required": [ - "event_id", - "name", - "tournament", - "reporter", - "participants", - "modified", - "token", - ], - "title": "Event", - "type": "object", - }, - "M2mWithO2oPk_leajz6_leaf": { - "additionalProperties": False, - "properties": { - "id": { - "maximum": 2147483647, - "minimum": -2147483648, - "title": "Id", - "type": "integer", - }, - "name": {"maxLength": 64, "title": "Name", "type": "string"}, - }, - "required": ["id", "name"], - "title": "M2mWithO2oPk", - "type": "object", - }, - "Reporter_fgnv33_leaf": { - "additionalProperties": False, - "description": "Whom is assigned as the reporter", - "properties": { - "id": { - "maximum": 2147483647, - "minimum": -2147483648, - "title": "Id", - "type": "integer", - }, - "name": {"title": "Name", "type": "string"}, - }, - "required": ["id", "name"], - "title": "Reporter", - "type": "object", - }, - "Team_ip4pg6_leaf": { - "additionalProperties": False, - "description": "Team that is a playing", - "properties": { - "id": { - "maximum": 2147483647, - "minimum": -2147483648, - "title": "Id", - "type": "integer", - }, - "name": {"title": "Name", "type": "string"}, - "alias": { - "anyOf": [ - { - "maximum": 2147483647, - "minimum": -2147483648, - "type": "integer", - }, - {"type": "null"}, - ], - "default": None, - "nullable": True, - "title": "Alias", - }, - }, - "required": ["id", "name"], - "title": "Team", - "type": "object", - }, - "Tournament_5y7e7j_leaf": { - "additionalProperties": False, - "properties": { - "id": { - "maximum": 32767, - "minimum": -32768, - "title": "Id", - "type": "integer", - }, - "name": {"maxLength": 255, 
"title": "Name", "type": "string"}, - "desc": { - "anyOf": [{"type": "string"}, {"type": "null"}], - "default": None, - "nullable": True, - "title": "Desc", - }, - "created": { - "format": "date-time", - "readOnly": True, - "title": "Created", - "type": "string", - }, - }, - "required": ["id", "name", "created"], - "title": "Tournament", - "type": "object", - }, - }, + "Address_e4rhju_leaf": { "additionalProperties": False, "properties": { "city": {"maxLength": 64, "title": "City", "type": "string"}, @@ -628,7 +383,6 @@ def test_address_schema(self): "title": "M2Mwitho2Opks", "type": "array", }, - "event": {"$ref": "#/$defs/Event_zvunzw_leaf"}, "event_id": { "maximum": 9223372036854775807, "minimum": -9223372036854775808, @@ -636,175 +390,84 @@ def test_address_schema(self): "type": "integer", }, }, - "required": ["city", "street", "event", "event_id", "m2mwitho2opks"], + "required": ["city", "street", "event_id", "m2mwitho2opks"], "title": "Address", "type": "object", }, - ) - - def test_tournament_schema(self): - self.assertEqual( - self.Tournament_Pydantic.model_json_schema(), - { - "$defs": { - "Event_ln6p2q_leaf": { - "additionalProperties": False, - "description": "Events on the calendar", - "properties": { - "event_id": { - "maximum": 9223372036854775807, - "minimum": -9223372036854775808, - "title": "Event Id", - "type": "integer", - }, - "name": {"description": "The name", "title": "Name", "type": "string"}, - "reporter": { - "anyOf": [ - {"$ref": "#/$defs/Reporter_fgnv33_leaf"}, - {"type": "null"}, - ], - "nullable": True, - "title": "Reporter", - }, - "participants": { - "items": {"$ref": "#/$defs/Team_ip4pg6_leaf"}, - "title": "Participants", - "type": "array", - }, - "modified": { - "format": "date-time", - "readOnly": True, - "title": "Modified", - "type": "string", - }, - "token": { - "anyOf": [{"type": "string"}, {"type": "null"}], - "title": "Token", - }, - "alias": { - "anyOf": [ - { - "maximum": 2147483647, - "minimum": -2147483648, - "type": 
"integer", - }, - {"type": "null"}, - ], - "default": None, - "nullable": True, - "title": "Alias", - }, - "address": { - "anyOf": [ - {"$ref": "#/$defs/Address_e4rhju_leaf"}, - {"type": "null"}, - ], - "nullable": True, - "title": "Address", - }, - }, - "required": [ - "event_id", - "name", - "reporter", - "participants", - "modified", - "token", - "address", - ], - "title": "Event", - "type": "object", - }, - "Address_e4rhju_leaf": { - "additionalProperties": False, - "properties": { - "city": {"maxLength": 64, "title": "City", "type": "string"}, - "street": {"maxLength": 128, "title": "Street", "type": "string"}, - "m2mwitho2opks": { - "items": {"$ref": "#/$defs/M2mWithO2oPk_leajz6_leaf"}, - "title": "M2Mwitho2Opks", - "type": "array", - }, - "event_id": { - "maximum": 9223372036854775807, - "minimum": -9223372036854775808, - "title": "Event Id", - "type": "integer", - }, - }, - "required": ["city", "street", "event_id", "m2mwitho2opks"], - "title": "Address", - "type": "object", - }, - "M2mWithO2oPk_leajz6_leaf": { - "additionalProperties": False, - "properties": { - "id": { - "maximum": 2147483647, - "minimum": -2147483648, - "title": "Id", - "type": "integer", - }, - "name": {"maxLength": 64, "title": "Name", "type": "string"}, - }, - "required": ["id", "name"], - "title": "M2mWithO2oPk", - "type": "object", - }, - "Reporter_fgnv33_leaf": { - "additionalProperties": False, - "description": "Whom is assigned as the reporter", - "properties": { - "id": { - "maximum": 2147483647, - "minimum": -2147483648, - "title": "Id", - "type": "integer", - }, - "name": {"title": "Name", "type": "string"}, - }, - "required": ["id", "name"], - "title": "Reporter", - "type": "object", - }, - "Team_ip4pg6_leaf": { - "additionalProperties": False, - "description": "Team that is a playing", - "properties": { - "id": { - "maximum": 2147483647, - "minimum": -2147483648, - "title": "Id", - "type": "integer", - }, - "name": {"title": "Name", "type": "string"}, - "alias": { - "anyOf": 
[ - { - "maximum": 2147483647, - "minimum": -2147483648, - "type": "integer", - }, - {"type": "null"}, - ], - "default": None, - "nullable": True, - "title": "Alias", - }, - }, - "required": ["id", "name"], - "title": "Team", - "type": "object", + "M2mWithO2oPk_leajz6_leaf": { + "additionalProperties": False, + "properties": { + "id": { + "maximum": 2147483647, + "minimum": -2147483648, + "title": "Id", + "type": "integer", }, + "name": {"maxLength": 64, "title": "Name", "type": "string"}, }, + "required": ["id", "name"], + "title": "M2mWithO2oPk", + "type": "object", + }, + "Reporter_fgnv33_leaf": { "additionalProperties": False, + "description": "Whom is assigned as the reporter", "properties": { - "id": {"maximum": 32767, "minimum": -32768, "title": "Id", "type": "integer"}, - "name": {"maxLength": 255, "title": "Name", "type": "string"}, - "desc": { - "anyOf": [{"type": "string"}, {"type": "null"}], - "default": None, - "nullable": True, + "id": { + "maximum": 2147483647, + "minimum": -2147483648, + "title": "Id", + "type": "integer", + }, + "name": {"title": "Name", "type": "string"}, + }, + "required": ["id", "name"], + "title": "Reporter", + "type": "object", + }, + "Team_ip4pg6_leaf": { + "additionalProperties": False, + "description": "Team that is a playing", + "properties": { + "id": { + "maximum": 2147483647, + "minimum": -2147483648, + "title": "Id", + "type": "integer", + }, + "name": {"title": "Name", "type": "string"}, + "alias": { + "anyOf": [ + { + "maximum": 2147483647, + "minimum": -2147483648, + "type": "integer", + }, + {"type": "null"}, + ], + "default": None, + "nullable": True, + "title": "Alias", + }, + }, + "required": ["id", "name"], + "title": "Team", + "type": "object", + }, + "Tournament_5y7e7j_leaf": { + "additionalProperties": False, + "properties": { + "id": { + "maximum": 32767, + "minimum": -32768, + "title": "Id", + "type": "integer", + }, + "name": {"maxLength": 255, "title": "Name", "type": "string"}, + "desc": { + "anyOf": 
[{"type": "string"}, {"type": "null"}], + "default": None, + "nullable": True, "title": "Desc", }, "created": { @@ -813,171 +476,120 @@ def test_tournament_schema(self): "title": "Created", "type": "string", }, - "events": { - "description": "What tournaments is a happenin'", - "items": {"$ref": "#/$defs/Event_ln6p2q_leaf"}, - "title": "Events", - "type": "array", - }, }, - "required": ["id", "name", "created", "events"], + "required": ["id", "name", "created"], "title": "Tournament", "type": "object", }, - ) - - def test_team_schema(self): - self.assertEqual( - self.Team_Pydantic.model_json_schema(), - { - "$defs": { - "Event_lfs4vy_leaf": { - "additionalProperties": False, - "description": "Events on the calendar", - "properties": { - "event_id": { - "maximum": 9223372036854775807, - "minimum": -9223372036854775808, - "title": "Event Id", - "type": "integer", - }, - "name": {"description": "The name", "title": "Name", "type": "string"}, - "tournament": { - "$ref": "#/$defs/Tournament_5y7e7j_leaf", - "description": "What tournaments is a happenin'", - }, - "reporter": { - "anyOf": [ - {"$ref": "#/$defs/Reporter_fgnv33_leaf"}, - {"type": "null"}, - ], - "nullable": True, - "title": "Reporter", - }, - "modified": { - "format": "date-time", - "readOnly": True, - "title": "Modified", - "type": "string", - }, - "token": { - "anyOf": [{"type": "string"}, {"type": "null"}], - "title": "Token", - }, - "alias": { - "anyOf": [ - { - "maximum": 2147483647, - "minimum": -2147483648, - "type": "integer", - }, - {"type": "null"}, - ], - "default": None, - "nullable": True, - "title": "Alias", - }, - "address": { - "anyOf": [ - {"$ref": "#/$defs/Address_e4rhju_leaf"}, - {"type": "null"}, - ], - "nullable": True, - "title": "Address", - }, - }, - "required": [ - "event_id", - "name", - "tournament", - "reporter", - "modified", - "token", - "address", + }, + "description": "Events on the calendar", + "items": {"$ref": "#/$defs/Event_mfxmwb"}, + "title": "Event_list", + "type": 
"array", + } + + +@pytest.mark.asyncio +async def test_address_schema(db, pydantic_setup): + Address_Pydantic = pydantic_setup["Address_Pydantic"] + assert Address_Pydantic.model_json_schema() == { + "$defs": { + "Event_zvunzw_leaf": { + "additionalProperties": False, + "description": "Events on the calendar", + "properties": { + "event_id": { + "maximum": 9223372036854775807, + "minimum": -9223372036854775808, + "title": "Event Id", + "type": "integer", + }, + "name": {"description": "The name", "title": "Name", "type": "string"}, + "tournament": { + "$ref": "#/$defs/Tournament_5y7e7j_leaf", + "description": "What tournaments is a happenin'", + }, + "reporter": { + "anyOf": [ + {"$ref": "#/$defs/Reporter_fgnv33_leaf"}, + {"type": "null"}, ], - "title": "Event", - "type": "object", - }, - "Address_e4rhju_leaf": { - "additionalProperties": False, - "properties": { - "city": {"maxLength": 64, "title": "City", "type": "string"}, - "street": {"maxLength": 128, "title": "Street", "type": "string"}, - "m2mwitho2opks": { - "items": {"$ref": "#/$defs/M2mWithO2oPk_leajz6_leaf"}, - "title": "M2Mwitho2Opks", - "type": "array", - }, - "event_id": { - "maximum": 9223372036854775807, - "minimum": -9223372036854775808, - "title": "Event Id", - "type": "integer", - }, - }, - "required": ["city", "street", "event_id", "m2mwitho2opks"], - "title": "Address", - "type": "object", + "nullable": True, + "title": "Reporter", }, - "M2mWithO2oPk_leajz6_leaf": { - "additionalProperties": False, - "properties": { - "id": { - "maximum": 2147483647, - "minimum": -2147483648, - "title": "Id", - "type": "integer", - }, - "name": {"maxLength": 64, "title": "Name", "type": "string"}, - }, - "required": ["id", "name"], - "title": "M2mWithO2oPk", - "type": "object", - }, - "Reporter_fgnv33_leaf": { - "additionalProperties": False, - "description": "Whom is assigned as the reporter", - "properties": { - "id": { + "participants": { + "items": {"$ref": "#/$defs/Team_ip4pg6_leaf"}, + "title": 
"Participants", + "type": "array", + }, + "modified": { + "format": "date-time", + "readOnly": True, + "title": "Modified", + "type": "string", + }, + "token": { + "anyOf": [{"type": "string"}, {"type": "null"}], + "title": "Token", + }, + "alias": { + "anyOf": [ + { "maximum": 2147483647, "minimum": -2147483648, - "title": "Id", - "type": "integer", - }, - "name": {"title": "Name", "type": "string"}, - }, - "required": ["id", "name"], - "title": "Reporter", - "type": "object", - }, - "Tournament_5y7e7j_leaf": { - "additionalProperties": False, - "properties": { - "id": { - "maximum": 32767, - "minimum": -32768, - "title": "Id", "type": "integer", }, - "name": {"maxLength": 255, "title": "Name", "type": "string"}, - "desc": { - "anyOf": [{"type": "string"}, {"type": "null"}], - "default": None, - "nullable": True, - "title": "Desc", - }, - "created": { - "format": "date-time", - "readOnly": True, - "title": "Created", - "type": "string", - }, - }, - "required": ["id", "name", "created"], - "title": "Tournament", - "type": "object", + {"type": "null"}, + ], + "default": None, + "nullable": True, + "title": "Alias", + }, + }, + "required": [ + "event_id", + "name", + "tournament", + "reporter", + "participants", + "modified", + "token", + ], + "title": "Event", + "type": "object", + }, + "M2mWithO2oPk_leajz6_leaf": { + "additionalProperties": False, + "properties": { + "id": { + "maximum": 2147483647, + "minimum": -2147483648, + "title": "Id", + "type": "integer", + }, + "name": {"maxLength": 64, "title": "Name", "type": "string"}, + }, + "required": ["id", "name"], + "title": "M2mWithO2oPk", + "type": "object", + }, + "Reporter_fgnv33_leaf": { + "additionalProperties": False, + "description": "Whom is assigned as the reporter", + "properties": { + "id": { + "maximum": 2147483647, + "minimum": -2147483648, + "title": "Id", + "type": "integer", }, + "name": {"title": "Name", "type": "string"}, }, + "required": ["id", "name"], + "title": "Reporter", + "type": "object", 
+ }, + "Team_ip4pg6_leaf": { "additionalProperties": False, "description": "Team that is a playing", "properties": { @@ -990,496 +602,895 @@ def test_team_schema(self): "name": {"title": "Name", "type": "string"}, "alias": { "anyOf": [ - {"maximum": 2147483647, "minimum": -2147483648, "type": "integer"}, + { + "maximum": 2147483647, + "minimum": -2147483648, + "type": "integer", + }, {"type": "null"}, ], "default": None, "nullable": True, "title": "Alias", }, - "events": { - "items": {"$ref": "#/$defs/Event_lfs4vy_leaf"}, - "title": "Events", - "type": "array", - }, }, - "required": ["id", "name", "events"], + "required": ["id", "name"], "title": "Team", "type": "object", }, - ) - - async def test_eventlist(self): - eventlp = await self.Event_Pydantic_List.from_queryset(Event.all()) - eventldict = eventlp.model_dump() - - # Remove timestamps - del eventldict[0]["modified"] - del eventldict[0]["tournament"]["created"] - del eventldict[1]["modified"] - del eventldict[1]["tournament"]["created"] - - self.assertEqual( - eventldict, - [ - { - "event_id": self.event.event_id, - "name": "Test", - # "modified": "2020-01-28T10:43:50.901562", - "token": self.event.token, - "alias": None, - "tournament": { - "id": self.tournament.id, - "name": "New Tournament", - "desc": None, - # "created": "2020-01-28T10:43:50.900664" - }, - "reporter": {"id": self.reporter.id, "name": "The Reporter"}, - "participants": [ - {"id": self.team1.id, "name": "Onesies", "alias": None}, - {"id": self.team2.id, "name": "T-Shirts", "alias": None}, - ], - "address": { - "event_id": self.address.pk, - "city": "Santa Monica", - "m2mwitho2opks": [], - "street": "Ocean", + "Tournament_5y7e7j_leaf": { + "additionalProperties": False, + "properties": { + "id": { + "maximum": 32767, + "minimum": -32768, + "title": "Id", + "type": "integer", + }, + "name": {"maxLength": 255, "title": "Name", "type": "string"}, + "desc": { + "anyOf": [{"type": "string"}, {"type": "null"}], + "default": None, + "nullable": 
True, + "title": "Desc", + }, + "created": { + "format": "date-time", + "readOnly": True, + "title": "Created", + "type": "string", }, }, - { - "event_id": self.event2.event_id, - "name": "Test2", - # "modified": "2020-01-28T10:43:50.901562", - "token": self.event2.token, - "alias": None, - "tournament": { - "id": self.tournament.id, - "name": "New Tournament", - "desc": None, - # "created": "2020-01-28T10:43:50.900664" - }, - "reporter": None, - "participants": [ - {"id": self.team1.id, "name": "Onesies", "alias": None}, - {"id": self.team2.id, "name": "T-Shirts", "alias": None}, - ], - "address": None, - }, - ], - ) - - async def test_event(self): - eventp = await self.Event_Pydantic.from_tortoise_orm(await Event.get(name="Test")) - eventdict = eventp.model_dump() - - # Remove timestamps - del eventdict["modified"] - del eventdict["tournament"]["created"] - - self.assertEqual( - eventdict, - { - "event_id": self.event.event_id, - "name": "Test", - # "modified": "2020-01-28T10:43:50.901562", - "token": self.event.token, - "alias": None, - "tournament": { - "id": self.tournament.id, - "name": "New Tournament", - "desc": None, - # "created": "2020-01-28T10:43:50.900664" - }, - "reporter": {"id": self.reporter.id, "name": "The Reporter"}, - "participants": [ - {"id": self.team1.id, "name": "Onesies", "alias": None}, - {"id": self.team2.id, "name": "T-Shirts", "alias": None}, - ], - "address": { - "event_id": self.address.pk, - "city": "Santa Monica", - "m2mwitho2opks": [], - "street": "Ocean", - }, + "required": ["id", "name", "created"], + "title": "Tournament", + "type": "object", }, - ) - - async def test_address(self): - addressp = await self.Address_Pydantic.from_tortoise_orm(await Address.get(street="Ocean")) - addressdict = addressp.model_dump() - - # Remove timestamps - del addressdict["event"]["tournament"]["created"] - del addressdict["event"]["modified"] - - self.assertEqual( - addressdict, - { - "city": "Santa Monica", - "street": "Ocean", - "event": { - 
"event_id": self.event.event_id, - "name": "Test", - "tournament": { - "id": self.tournament.id, - "name": "New Tournament", - "desc": None, - }, - "reporter": {"id": self.reporter.id, "name": "The Reporter"}, - "participants": [ - {"id": self.team1.id, "name": "Onesies", "alias": None}, - {"id": self.team2.id, "name": "T-Shirts", "alias": None}, - ], - "token": self.event.token, - "alias": None, - }, - "event_id": self.address.event_id, - "m2mwitho2opks": [], + }, + "additionalProperties": False, + "properties": { + "city": {"maxLength": 64, "title": "City", "type": "string"}, + "street": {"maxLength": 128, "title": "Street", "type": "string"}, + "m2mwitho2opks": { + "items": {"$ref": "#/$defs/M2mWithO2oPk_leajz6_leaf"}, + "title": "M2Mwitho2Opks", + "type": "array", }, - ) - - async def test_tournament(self): - tournamentp = await self.Tournament_Pydantic.from_tortoise_orm( - await Tournament.all().first() - ) - tournamentdict = tournamentp.model_dump() - - # Remove timestamps - del tournamentdict["events"][0]["modified"] - del tournamentdict["events"][1]["modified"] - del tournamentdict["created"] - - self.assertEqual( - tournamentdict, - { - "id": self.tournament.id, - "name": "New Tournament", - "desc": None, - # "created": "2020-01-28T19:41:38.059617", - "events": [ - { - "event_id": self.event.event_id, - "name": "Test", - # "modified": "2020-01-28T19:41:38.060070", - "token": self.event.token, - "alias": None, - "reporter": {"id": self.reporter.id, "name": "The Reporter"}, - "participants": [ - {"id": self.team1.id, "name": "Onesies", "alias": None}, - {"id": self.team2.id, "name": "T-Shirts", "alias": None}, - ], - "address": { - "event_id": self.address.pk, - "city": "Santa Monica", - "m2mwitho2opks": [], - "street": "Ocean", - }, + "event": {"$ref": "#/$defs/Event_zvunzw_leaf"}, + "event_id": { + "maximum": 9223372036854775807, + "minimum": -9223372036854775808, + "title": "Event Id", + "type": "integer", + }, + }, + "required": ["city", "street", 
"event", "event_id", "m2mwitho2opks"], + "title": "Address", + "type": "object", + } + + +@pytest.mark.asyncio +async def test_tournament_schema(db, pydantic_setup): + Tournament_Pydantic = pydantic_setup["Tournament_Pydantic"] + assert Tournament_Pydantic.model_json_schema() == { + "$defs": { + "Event_ln6p2q_leaf": { + "additionalProperties": False, + "description": "Events on the calendar", + "properties": { + "event_id": { + "maximum": 9223372036854775807, + "minimum": -9223372036854775808, + "title": "Event Id", + "type": "integer", }, - { - "event_id": self.event2.event_id, - "name": "Test2", - # "modified": "2020-01-28T19:41:38.060070", - "token": self.event2.token, - "alias": None, - "reporter": None, - "participants": [ - {"id": self.team1.id, "name": "Onesies", "alias": None}, - {"id": self.team2.id, "name": "T-Shirts", "alias": None}, + "name": {"description": "The name", "title": "Name", "type": "string"}, + "reporter": { + "anyOf": [ + {"$ref": "#/$defs/Reporter_fgnv33_leaf"}, + {"type": "null"}, ], - "address": None, - }, - ], - }, - ) - - async def test_team(self): - teamp = await self.Team_Pydantic.from_tortoise_orm(await Team.get(id=self.team1.id)) - teamdict = teamp.model_dump() - - # Remove timestamps - del teamdict["events"][0]["modified"] - del teamdict["events"][0]["tournament"]["created"] - del teamdict["events"][1]["modified"] - del teamdict["events"][1]["tournament"]["created"] - - self.assertEqual( - teamdict, - { - "id": self.team1.id, - "name": "Onesies", - "alias": None, - "events": [ - { - "event_id": self.event.event_id, - "name": "Test", - # "modified": "2020-01-28T19:47:03.334077", - "token": self.event.token, - "alias": None, - "tournament": { - "id": self.tournament.id, - "name": "New Tournament", - "desc": None, - # "created": "2020-01-28T19:41:38.059617", - }, - "reporter": {"id": self.reporter.id, "name": "The Reporter"}, - "address": { - "event_id": self.address.pk, - "city": "Santa Monica", - "m2mwitho2opks": [], - "street": 
"Ocean", - }, + "nullable": True, + "title": "Reporter", }, - { - "event_id": self.event2.event_id, - "name": "Test2", - # "modified": "2020-01-28T19:47:03.334077", - "token": self.event2.token, - "alias": None, - "tournament": { - "id": self.tournament.id, - "name": "New Tournament", - "desc": None, - # "created": "2020-01-28T19:41:38.059617", - }, - "reporter": None, - "address": None, + "participants": { + "items": {"$ref": "#/$defs/Team_ip4pg6_leaf"}, + "title": "Participants", + "type": "array", }, + "modified": { + "format": "date-time", + "readOnly": True, + "title": "Modified", + "type": "string", + }, + "token": { + "anyOf": [{"type": "string"}, {"type": "null"}], + "title": "Token", + }, + "alias": { + "anyOf": [ + { + "maximum": 2147483647, + "minimum": -2147483648, + "type": "integer", + }, + {"type": "null"}, + ], + "default": None, + "nullable": True, + "title": "Alias", + }, + "address": { + "anyOf": [ + {"$ref": "#/$defs/Address_e4rhju_leaf"}, + {"type": "null"}, + ], + "nullable": True, + "title": "Address", + }, + }, + "required": [ + "event_id", + "name", + "reporter", + "participants", + "modified", + "token", + "address", ], + "title": "Event", + "type": "object", }, - ) - - def test_event_named(self): - Event_Named = pydantic_model_creator(Event, name="Foo") - schema = Event_Named.model_json_schema() - self.assertEqual(schema["title"], "Foo") - self.assertSetEqual( - set(schema["properties"].keys()), - { - "address", - "alias", - "event_id", - "modified", - "name", - "participants", - "reporter", - "token", - "tournament", + "Address_e4rhju_leaf": { + "additionalProperties": False, + "properties": { + "city": {"maxLength": 64, "title": "City", "type": "string"}, + "street": {"maxLength": 128, "title": "Street", "type": "string"}, + "m2mwitho2opks": { + "items": {"$ref": "#/$defs/M2mWithO2oPk_leajz6_leaf"}, + "title": "M2Mwitho2Opks", + "type": "array", + }, + "event_id": { + "maximum": 9223372036854775807, + "minimum": -9223372036854775808, + 
"title": "Event Id", + "type": "integer", + }, + }, + "required": ["city", "street", "event_id", "m2mwitho2opks"], + "title": "Address", + "type": "object", }, - ) - - def test_event_sorted(self): - Event_Named = pydantic_model_creator(Event, sort_alphabetically=True) - schema = Event_Named.model_json_schema() - self.assertEqual( - list(schema["properties"].keys()), - [ - "address", - "alias", - "event_id", - "modified", - "name", - "participants", - "reporter", - "token", - "tournament", - ], - ) - - def test_event_unsorted(self): - Event_Named = pydantic_model_creator(Event, sort_alphabetically=False) - schema = Event_Named.model_json_schema() - self.assertEqual( - list(schema["properties"].keys()), - [ - "event_id", - "name", - "tournament", - "reporter", - "participants", - "modified", - "token", - "alias", - "address", - ], - ) - - async def test_json_field(self): - json_field_0 = await JSONFields.create(data={"a": 1}) - json_field_1 = await JSONFields.create(data=[{"a": 1, "b": 2}]) - json_field_0_get = await JSONFields.get(pk=json_field_0.pk) - json_field_1_get = await JSONFields.get(pk=json_field_1.pk) - - creator = pydantic_model_creator(JSONFields) - ret0 = creator.model_validate(json_field_0_get).model_dump() - self.assertEqual( - ret0, - { - "id": json_field_0.pk, - "data": {"a": 1}, - "data_null": None, - "data_default": {"a": 1}, - "data_validate": None, - "data_pydantic": json_pydantic_default.model_dump(), + "M2mWithO2oPk_leajz6_leaf": { + "additionalProperties": False, + "properties": { + "id": { + "maximum": 2147483647, + "minimum": -2147483648, + "title": "Id", + "type": "integer", + }, + "name": {"maxLength": 64, "title": "Name", "type": "string"}, + }, + "required": ["id", "name"], + "title": "M2mWithO2oPk", + "type": "object", }, - ) - ret1 = creator.model_validate(json_field_1_get).model_dump() - self.assertEqual( - ret1, - { - "id": json_field_1.pk, - "data": [{"a": 1, "b": 2}], - "data_null": None, - "data_default": {"a": 1}, - 
"data_validate": None, - "data_pydantic": json_pydantic_default.model_dump(), + "Reporter_fgnv33_leaf": { + "additionalProperties": False, + "description": "Whom is assigned as the reporter", + "properties": { + "id": { + "maximum": 2147483647, + "minimum": -2147483648, + "title": "Id", + "type": "integer", + }, + "name": {"title": "Name", "type": "string"}, + }, + "required": ["id", "name"], + "title": "Reporter", + "type": "object", }, - ) - - def test_override_default_model_config_by_config_class(self): - """Pydantic meta's config_class should be able to override default config.""" - CamelCaseAliasPersonCopy = copy.deepcopy(CamelCaseAliasPerson) - # Set class pydantic config's orm_mode to False - CamelCaseAliasPersonCopy.PydanticMeta.model_config["from_attributes"] = False - - ModelPydantic = pydantic_model_creator( - CamelCaseAliasPerson, name="AutoAliasPersonOverriddenORMMode" - ) - - self.assertEqual(ModelPydantic.model_config["from_attributes"], False) - - def test_override_meta_pydantic_config_by_model_creator(self): - model_config = ConfigDict(title="Another title!") - - ModelPydantic = pydantic_model_creator( - CamelCaseAliasPerson, - model_config=model_config, - name="AutoAliasPersonModelCreatorConfig", - ) - - self.assertEqual(model_config["title"], ModelPydantic.model_config["title"]) - - def test_config_classes_merge_all_configs(self): - """Model creator should merge all 3 configs. - - - It merges (Default, Meta's config_class and creator's config_class) together. 
- """ - model_config = ConfigDict(str_min_length=3) - - ModelPydantic = pydantic_model_creator( - CamelCaseAliasPerson, name="AutoAliasPersonMinLength", model_config=model_config - ) - - # Should set min_anystr_length from pydantic_model_creator's config - self.assertEqual( - ModelPydantic.model_config["str_min_length"], model_config["str_min_length"] - ) - # Should set title from model PydanticMeta's config - self.assertEqual( - ModelPydantic.model_config["title"], - CamelCaseAliasPerson.PydanticMeta.model_config["title"], - ) - # Should set orm_mode from base pydantic model configuration - self.assertEqual( - ModelPydantic.model_config["from_attributes"], - PydanticModel.model_config["from_attributes"], - ) - - def test_exclude_readonly(self): - ModelPydantic = pydantic_model_creator(Event, exclude_readonly=True) - - self.assertNotIn("modified", ModelPydantic.model_json_schema()["properties"]) - - -class TestPydanticCycle(test.TestCase): - async def asyncSetUp(self) -> None: - await super().asyncSetUp() - self.Employee_Pydantic = pydantic_model_creator(Employee) - - self.root = await Employee.create(name="Root") - self.loose = await Employee.create(name="Loose") - self._1 = await Employee.create(name="1. First H1", manager=self.root) - self._2 = await Employee.create(name="2. Second H1", manager=self.root) - self._1_1 = await Employee.create(name="1.1. First H2", manager=self._1) - self._1_1_1 = await Employee.create(name="1.1.1. First H3", manager=self._1_1) - self._2_1 = await Employee.create(name="2.1. Second H2", manager=self._2) - self._2_2 = await Employee.create(name="2.2. 
Third H2", manager=self._2) - - await self._1.talks_to.add(self._2, self._1_1_1, self.loose) - await self._2_1.gets_talked_to.add(self._2_2, self._1_1, self.loose) - self.maxDiff = None - - def test_schema(self): - self.assertEqual( - self.Employee_Pydantic.model_json_schema(), - { - "$defs": { - "Employee_6tkbjb_leaf": { - "additionalProperties": False, - "properties": { - "id": { + "Team_ip4pg6_leaf": { + "additionalProperties": False, + "description": "Team that is a playing", + "properties": { + "id": { + "maximum": 2147483647, + "minimum": -2147483648, + "title": "Id", + "type": "integer", + }, + "name": {"title": "Name", "type": "string"}, + "alias": { + "anyOf": [ + { "maximum": 2147483647, "minimum": -2147483648, - "title": "Id", "type": "integer", }, - "name": {"maxLength": 50, "title": "Name", "type": "string"}, - "talks_to": { - "items": {"$ref": "#/$defs/Employee_fj2ly4_leaf"}, - "title": "Talks To", - "type": "array", - }, - "manager_id": { - "anyOf": [ - { - "maximum": 2147483647, - "minimum": -2147483648, - "type": "integer", - }, - {"type": "null"}, - ], - "default": None, - "nullable": True, - "title": "Manager Id", - }, - "team_members": { - "items": {"$ref": "#/$defs/Employee_fj2ly4_leaf"}, - "title": "Team Members", - "type": "array", - }, - }, - "required": ["id", "name", "talks_to", "team_members"], - "title": "Employee", - "type": "object", - }, - "Employee_fj2ly4_leaf": { - "additionalProperties": False, - "properties": { - "id": { + {"type": "null"}, + ], + "default": None, + "nullable": True, + "title": "Alias", + }, + }, + "required": ["id", "name"], + "title": "Team", + "type": "object", + }, + }, + "additionalProperties": False, + "properties": { + "id": {"maximum": 32767, "minimum": -32768, "title": "Id", "type": "integer"}, + "name": {"maxLength": 255, "title": "Name", "type": "string"}, + "desc": { + "anyOf": [{"type": "string"}, {"type": "null"}], + "default": None, + "nullable": True, + "title": "Desc", + }, + "created": { + 
"format": "date-time", + "readOnly": True, + "title": "Created", + "type": "string", + }, + "events": { + "description": "What tournaments is a happenin'", + "items": {"$ref": "#/$defs/Event_ln6p2q_leaf"}, + "title": "Events", + "type": "array", + }, + }, + "required": ["id", "name", "created", "events"], + "title": "Tournament", + "type": "object", + } + + +@pytest.mark.asyncio +async def test_team_schema(db, pydantic_setup): + Team_Pydantic = pydantic_setup["Team_Pydantic"] + assert Team_Pydantic.model_json_schema() == { + "$defs": { + "Event_lfs4vy_leaf": { + "additionalProperties": False, + "description": "Events on the calendar", + "properties": { + "event_id": { + "maximum": 9223372036854775807, + "minimum": -9223372036854775808, + "title": "Event Id", + "type": "integer", + }, + "name": {"description": "The name", "title": "Name", "type": "string"}, + "tournament": { + "$ref": "#/$defs/Tournament_5y7e7j_leaf", + "description": "What tournaments is a happenin'", + }, + "reporter": { + "anyOf": [ + {"$ref": "#/$defs/Reporter_fgnv33_leaf"}, + {"type": "null"}, + ], + "nullable": True, + "title": "Reporter", + }, + "modified": { + "format": "date-time", + "readOnly": True, + "title": "Modified", + "type": "string", + }, + "token": { + "anyOf": [{"type": "string"}, {"type": "null"}], + "title": "Token", + }, + "alias": { + "anyOf": [ + { "maximum": 2147483647, "minimum": -2147483648, - "title": "Id", "type": "integer", }, - "name": {"maxLength": 50, "title": "Name", "type": "string"}, - "manager_id": { - "anyOf": [ - { - "maximum": 2147483647, - "minimum": -2147483648, - "type": "integer", - }, - {"type": "null"}, - ], - "default": None, - "nullable": True, - "title": "Manager Id", - }, - }, - "required": ["id", "name"], - "title": "Employee", - "type": "object", + {"type": "null"}, + ], + "default": None, + "nullable": True, + "title": "Alias", + }, + "address": { + "anyOf": [ + {"$ref": "#/$defs/Address_e4rhju_leaf"}, + {"type": "null"}, + ], + "nullable": 
True, + "title": "Address", + }, + }, + "required": [ + "event_id", + "name", + "tournament", + "reporter", + "modified", + "token", + "address", + ], + "title": "Event", + "type": "object", + }, + "Address_e4rhju_leaf": { + "additionalProperties": False, + "properties": { + "city": {"maxLength": 64, "title": "City", "type": "string"}, + "street": {"maxLength": 128, "title": "Street", "type": "string"}, + "m2mwitho2opks": { + "items": {"$ref": "#/$defs/M2mWithO2oPk_leajz6_leaf"}, + "title": "M2Mwitho2Opks", + "type": "array", + }, + "event_id": { + "maximum": 9223372036854775807, + "minimum": -9223372036854775808, + "title": "Event Id", + "type": "integer", + }, + }, + "required": ["city", "street", "event_id", "m2mwitho2opks"], + "title": "Address", + "type": "object", + }, + "M2mWithO2oPk_leajz6_leaf": { + "additionalProperties": False, + "properties": { + "id": { + "maximum": 2147483647, + "minimum": -2147483648, + "title": "Id", + "type": "integer", + }, + "name": {"maxLength": 64, "title": "Name", "type": "string"}, + }, + "required": ["id", "name"], + "title": "M2mWithO2oPk", + "type": "object", + }, + "Reporter_fgnv33_leaf": { + "additionalProperties": False, + "description": "Whom is assigned as the reporter", + "properties": { + "id": { + "maximum": 2147483647, + "minimum": -2147483648, + "title": "Id", + "type": "integer", }, + "name": {"title": "Name", "type": "string"}, + }, + "required": ["id", "name"], + "title": "Reporter", + "type": "object", + }, + "Tournament_5y7e7j_leaf": { + "additionalProperties": False, + "properties": { + "id": { + "maximum": 32767, + "minimum": -32768, + "title": "Id", + "type": "integer", + }, + "name": {"maxLength": 255, "title": "Name", "type": "string"}, + "desc": { + "anyOf": [{"type": "string"}, {"type": "null"}], + "default": None, + "nullable": True, + "title": "Desc", + }, + "created": { + "format": "date-time", + "readOnly": True, + "title": "Created", + "type": "string", + }, + }, + "required": ["id", "name", 
"created"], + "title": "Tournament", + "type": "object", + }, + }, + "additionalProperties": False, + "description": "Team that is a playing", + "properties": { + "id": { + "maximum": 2147483647, + "minimum": -2147483648, + "title": "Id", + "type": "integer", + }, + "name": {"title": "Name", "type": "string"}, + "alias": { + "anyOf": [ + {"maximum": 2147483647, "minimum": -2147483648, "type": "integer"}, + {"type": "null"}, + ], + "default": None, + "nullable": True, + "title": "Alias", + }, + "events": { + "items": {"$ref": "#/$defs/Event_lfs4vy_leaf"}, + "title": "Events", + "type": "array", + }, + }, + "required": ["id", "name", "events"], + "title": "Team", + "type": "object", + } + + +@pytest.mark.asyncio +async def test_eventlist(db, pydantic_setup): + Event_Pydantic_List = pydantic_setup["Event_Pydantic_List"] + event = pydantic_setup["event"] + event2 = pydantic_setup["event2"] + tournament = pydantic_setup["tournament"] + reporter = pydantic_setup["reporter"] + team1 = pydantic_setup["team1"] + team2 = pydantic_setup["team2"] + address = pydantic_setup["address"] + + eventlp = await Event_Pydantic_List.from_queryset(Event.all()) + eventldict = eventlp.model_dump() + + # Remove timestamps + del eventldict[0]["modified"] + del eventldict[0]["tournament"]["created"] + del eventldict[1]["modified"] + del eventldict[1]["tournament"]["created"] + + assert eventldict == [ + { + "event_id": event.event_id, + "name": "Test", + # "modified": "2020-01-28T10:43:50.901562", + "token": event.token, + "alias": None, + "tournament": { + "id": tournament.id, + "name": "New Tournament", + "desc": None, + # "created": "2020-01-28T10:43:50.900664" + }, + "reporter": {"id": reporter.id, "name": "The Reporter"}, + "participants": [ + {"id": team1.id, "name": "Onesies", "alias": None}, + {"id": team2.id, "name": "T-Shirts", "alias": None}, + ], + "address": { + "event_id": address.pk, + "city": "Santa Monica", + "m2mwitho2opks": [], + "street": "Ocean", + }, + }, + { + 
"event_id": event2.event_id, + "name": "Test2", + # "modified": "2020-01-28T10:43:50.901562", + "token": event2.token, + "alias": None, + "tournament": { + "id": tournament.id, + "name": "New Tournament", + "desc": None, + # "created": "2020-01-28T10:43:50.900664" + }, + "reporter": None, + "participants": [ + {"id": team1.id, "name": "Onesies", "alias": None}, + {"id": team2.id, "name": "T-Shirts", "alias": None}, + ], + "address": None, + }, + ] + + +@pytest.mark.asyncio +async def test_event(db, pydantic_setup): + Event_Pydantic = pydantic_setup["Event_Pydantic"] + event = pydantic_setup["event"] + tournament = pydantic_setup["tournament"] + reporter = pydantic_setup["reporter"] + team1 = pydantic_setup["team1"] + team2 = pydantic_setup["team2"] + address = pydantic_setup["address"] + + eventp = await Event_Pydantic.from_tortoise_orm(await Event.get(name="Test")) + eventdict = eventp.model_dump() + + # Remove timestamps + del eventdict["modified"] + del eventdict["tournament"]["created"] + + assert eventdict == { + "event_id": event.event_id, + "name": "Test", + # "modified": "2020-01-28T10:43:50.901562", + "token": event.token, + "alias": None, + "tournament": { + "id": tournament.id, + "name": "New Tournament", + "desc": None, + # "created": "2020-01-28T10:43:50.900664" + }, + "reporter": {"id": reporter.id, "name": "The Reporter"}, + "participants": [ + {"id": team1.id, "name": "Onesies", "alias": None}, + {"id": team2.id, "name": "T-Shirts", "alias": None}, + ], + "address": { + "event_id": address.pk, + "city": "Santa Monica", + "m2mwitho2opks": [], + "street": "Ocean", + }, + } + + +@pytest.mark.asyncio +async def test_address(db, pydantic_setup): + Address_Pydantic = pydantic_setup["Address_Pydantic"] + event = pydantic_setup["event"] + tournament = pydantic_setup["tournament"] + reporter = pydantic_setup["reporter"] + team1 = pydantic_setup["team1"] + team2 = pydantic_setup["team2"] + address = pydantic_setup["address"] + + addressp = await 
Address_Pydantic.from_tortoise_orm(await Address.get(street="Ocean")) + addressdict = addressp.model_dump() + + # Remove timestamps + del addressdict["event"]["tournament"]["created"] + del addressdict["event"]["modified"] + + assert addressdict == { + "city": "Santa Monica", + "street": "Ocean", + "event": { + "event_id": event.event_id, + "name": "Test", + "tournament": { + "id": tournament.id, + "name": "New Tournament", + "desc": None, + }, + "reporter": {"id": reporter.id, "name": "The Reporter"}, + "participants": [ + {"id": team1.id, "name": "Onesies", "alias": None}, + {"id": team2.id, "name": "T-Shirts", "alias": None}, + ], + "token": event.token, + "alias": None, + }, + "event_id": address.event_id, + "m2mwitho2opks": [], + } + + +@pytest.mark.asyncio +async def test_tournament(db, pydantic_setup): + Tournament_Pydantic = pydantic_setup["Tournament_Pydantic"] + event = pydantic_setup["event"] + event2 = pydantic_setup["event2"] + tournament = pydantic_setup["tournament"] + reporter = pydantic_setup["reporter"] + team1 = pydantic_setup["team1"] + team2 = pydantic_setup["team2"] + address = pydantic_setup["address"] + + tournamentp = await Tournament_Pydantic.from_tortoise_orm(await Tournament.all().first()) + tournamentdict = tournamentp.model_dump() + + # Remove timestamps + del tournamentdict["events"][0]["modified"] + del tournamentdict["events"][1]["modified"] + del tournamentdict["created"] + + assert tournamentdict == { + "id": tournament.id, + "name": "New Tournament", + "desc": None, + # "created": "2020-01-28T19:41:38.059617", + "events": [ + { + "event_id": event.event_id, + "name": "Test", + # "modified": "2020-01-28T19:41:38.060070", + "token": event.token, + "alias": None, + "reporter": {"id": reporter.id, "name": "The Reporter"}, + "participants": [ + {"id": team1.id, "name": "Onesies", "alias": None}, + {"id": team2.id, "name": "T-Shirts", "alias": None}, + ], + "address": { + "event_id": address.pk, + "city": "Santa Monica", + 
"m2mwitho2opks": [], + "street": "Ocean", + }, + }, + { + "event_id": event2.event_id, + "name": "Test2", + # "modified": "2020-01-28T19:41:38.060070", + "token": event2.token, + "alias": None, + "reporter": None, + "participants": [ + {"id": team1.id, "name": "Onesies", "alias": None}, + {"id": team2.id, "name": "T-Shirts", "alias": None}, + ], + "address": None, + }, + ], + } + + +@pytest.mark.asyncio +async def test_team(db, pydantic_setup): + Team_Pydantic = pydantic_setup["Team_Pydantic"] + event = pydantic_setup["event"] + event2 = pydantic_setup["event2"] + tournament = pydantic_setup["tournament"] + reporter = pydantic_setup["reporter"] + team1 = pydantic_setup["team1"] + address = pydantic_setup["address"] + + teamp = await Team_Pydantic.from_tortoise_orm(await Team.get(id=team1.id)) + teamdict = teamp.model_dump() + + # Remove timestamps + del teamdict["events"][0]["modified"] + del teamdict["events"][0]["tournament"]["created"] + del teamdict["events"][1]["modified"] + del teamdict["events"][1]["tournament"]["created"] + + assert teamdict == { + "id": team1.id, + "name": "Onesies", + "alias": None, + "events": [ + { + "event_id": event.event_id, + "name": "Test", + # "modified": "2020-01-28T19:47:03.334077", + "token": event.token, + "alias": None, + "tournament": { + "id": tournament.id, + "name": "New Tournament", + "desc": None, + # "created": "2020-01-28T19:41:38.059617", + }, + "reporter": {"id": reporter.id, "name": "The Reporter"}, + "address": { + "event_id": address.pk, + "city": "Santa Monica", + "m2mwitho2opks": [], + "street": "Ocean", + }, + }, + { + "event_id": event2.event_id, + "name": "Test2", + # "modified": "2020-01-28T19:47:03.334077", + "token": event2.token, + "alias": None, + "tournament": { + "id": tournament.id, + "name": "New Tournament", + "desc": None, + # "created": "2020-01-28T19:41:38.059617", }, + "reporter": None, + "address": None, + }, + ], + } + + +@pytest.mark.asyncio +async def test_event_named(db, pydantic_setup): + 
Event_Named = pydantic_model_creator(Event, name="Foo") + schema = Event_Named.model_json_schema() + assert schema["title"] == "Foo" + assert set(schema["properties"].keys()) == { + "address", + "alias", + "event_id", + "modified", + "name", + "participants", + "reporter", + "token", + "tournament", + } + + +@pytest.mark.asyncio +async def test_event_sorted(db, pydantic_setup): + Event_Named = pydantic_model_creator(Event, sort_alphabetically=True) + schema = Event_Named.model_json_schema() + assert list(schema["properties"].keys()) == [ + "address", + "alias", + "event_id", + "modified", + "name", + "participants", + "reporter", + "token", + "tournament", + ] + + +@pytest.mark.asyncio +async def test_event_unsorted(db, pydantic_setup): + Event_Named = pydantic_model_creator(Event, sort_alphabetically=False) + schema = Event_Named.model_json_schema() + assert list(schema["properties"].keys()) == [ + "event_id", + "name", + "tournament", + "reporter", + "participants", + "modified", + "token", + "alias", + "address", + ] + + +@pytest.mark.asyncio +async def test_json_field(db): + json_field_0 = await JSONFields.create(data={"a": 1}) + json_field_1 = await JSONFields.create(data=[{"a": 1, "b": 2}]) + json_field_0_get = await JSONFields.get(pk=json_field_0.pk) + json_field_1_get = await JSONFields.get(pk=json_field_1.pk) + + creator = pydantic_model_creator(JSONFields) + ret0 = creator.model_validate(json_field_0_get).model_dump() + assert ret0 == { + "id": json_field_0.pk, + "data": {"a": 1}, + "data_null": None, + "data_default": {"a": 1}, + "data_validate": None, + "data_pydantic": json_pydantic_default.model_dump(), + } + ret1 = creator.model_validate(json_field_1_get).model_dump() + assert ret1 == { + "id": json_field_1.pk, + "data": [{"a": 1, "b": 2}], + "data_null": None, + "data_default": {"a": 1}, + "data_validate": None, + "data_pydantic": json_pydantic_default.model_dump(), + } + + +def test_override_default_model_config_by_config_class(db): + """Pydantic 
meta's config_class should be able to override default config.""" + # Save original value to restore after test + original_value = CamelCaseAliasPerson.PydanticMeta.model_config.get("from_attributes") + try: + # Set class pydantic config's from_attributes to False + CamelCaseAliasPerson.PydanticMeta.model_config["from_attributes"] = False + + ModelPydantic = pydantic_model_creator( + CamelCaseAliasPerson, name="AutoAliasPersonOverriddenORMMode" + ) + + assert ModelPydantic.model_config["from_attributes"] is False + finally: + # Restore original value to avoid polluting other tests + if original_value is None: + CamelCaseAliasPerson.PydanticMeta.model_config.pop("from_attributes", None) + else: + CamelCaseAliasPerson.PydanticMeta.model_config["from_attributes"] = original_value + + +def test_override_meta_pydantic_config_by_model_creator(db): + model_config = ConfigDict(title="Another title!") + + ModelPydantic = pydantic_model_creator( + CamelCaseAliasPerson, + model_config=model_config, + name="AutoAliasPersonModelCreatorConfig", + ) + + assert model_config["title"] == ModelPydantic.model_config["title"] + + +def test_config_classes_merge_all_configs(db): + """Model creator should merge all 3 configs. + + - It merges (Default, Meta's config_class and creator's config_class) together. 
+ """ + model_config = ConfigDict(str_min_length=3) + + ModelPydantic = pydantic_model_creator( + CamelCaseAliasPerson, name="AutoAliasPersonMinLength", model_config=model_config + ) + + # Should set min_anystr_length from pydantic_model_creator's config + assert ModelPydantic.model_config["str_min_length"] == model_config["str_min_length"] + # Should set title from model PydanticMeta's config + assert ( + ModelPydantic.model_config["title"] + == CamelCaseAliasPerson.PydanticMeta.model_config["title"] + ) + # Should set orm_mode from base pydantic model configuration + assert ( + ModelPydantic.model_config["from_attributes"] + == PydanticModel.model_config["from_attributes"] + ) + + +def test_exclude_readonly(db): + ModelPydantic = pydantic_model_creator(Event, exclude_readonly=True) + + assert "modified" not in ModelPydantic.model_json_schema()["properties"] + + +# Fixtures for TestPydanticCycle +@pytest_asyncio.fixture +async def pydantic_cycle_setup(db): + """Setup for pydantic cycle tests with employee hierarchy.""" + Employee_Pydantic = pydantic_model_creator(Employee) + + root = await Employee.create(name="Root") + loose = await Employee.create(name="Loose") + _1 = await Employee.create(name="1. First H1", manager=root) + _2 = await Employee.create(name="2. Second H1", manager=root) + _1_1 = await Employee.create(name="1.1. First H2", manager=_1) + _1_1_1 = await Employee.create(name="1.1.1. First H3", manager=_1_1) + _2_1 = await Employee.create(name="2.1. Second H2", manager=_2) + _2_2 = await Employee.create(name="2.2. 
Third H2", manager=_2) + + await _1.talks_to.add(_2, _1_1_1, loose) + await _2_1.gets_talked_to.add(_2_2, _1_1, loose) + + return { + "Employee_Pydantic": Employee_Pydantic, + "root": root, + "loose": loose, + "_1": _1, + "_2": _2, + "_1_1": _1_1, + "_1_1_1": _1_1_1, + "_2_1": _2_1, + "_2_2": _2_2, + } + + +@pytest.mark.asyncio +async def test_cycle_schema(db, pydantic_cycle_setup): + Employee_Pydantic = pydantic_cycle_setup["Employee_Pydantic"] + assert Employee_Pydantic.model_json_schema() == { + "$defs": { + "Employee_6tkbjb_leaf": { "additionalProperties": False, "properties": { "id": { @@ -1490,13 +1501,17 @@ def test_schema(self): }, "name": {"maxLength": 50, "title": "Name", "type": "string"}, "talks_to": { - "items": {"$ref": "#/$defs/Employee_6tkbjb_leaf"}, + "items": {"$ref": "#/$defs/Employee_fj2ly4_leaf"}, "title": "Talks To", "type": "array", }, "manager_id": { "anyOf": [ - {"maximum": 2147483647, "minimum": -2147483648, "type": "integer"}, + { + "maximum": 2147483647, + "minimum": -2147483648, + "type": "integer", + }, {"type": "null"}, ], "default": None, @@ -1504,7 +1519,7 @@ def test_schema(self): "title": "Manager Id", }, "team_members": { - "items": {"$ref": "#/$defs/Employee_6tkbjb_leaf"}, + "items": {"$ref": "#/$defs/Employee_fj2ly4_leaf"}, "title": "Team Members", "type": "array", }, @@ -1513,207 +1528,231 @@ def test_schema(self): "title": "Employee", "type": "object", }, - ) - - async def test_serialisation(self): - empp = await self.Employee_Pydantic.from_tortoise_orm(await Employee.get(name="Root")) - empdict = empp.model_dump() - - self.assertEqual( - empdict, - { - "id": self.root.id, - "manager_id": None, - "name": "Root", - "talks_to": [], - "team_members": [ - { - "id": self._1.id, - "manager_id": self.root.id, - "name": "1. First H1", - "talks_to": [ - { - "id": self.loose.id, - "manager_id": None, - "name": "Loose", - "name_length": 5, - "team_size": 0, - }, - { - "id": self._2.id, - "manager_id": self.root.id, - "name": "2. 
Second H1", - "name_length": 12, - "team_size": 0, - }, + "Employee_fj2ly4_leaf": { + "additionalProperties": False, + "properties": { + "id": { + "maximum": 2147483647, + "minimum": -2147483648, + "title": "Id", + "type": "integer", + }, + "name": {"maxLength": 50, "title": "Name", "type": "string"}, + "manager_id": { + "anyOf": [ { - "id": self._1_1_1.id, - "manager_id": self._1_1.id, - "name": "1.1.1. First H3", - "name_length": 15, - "team_size": 0, + "maximum": 2147483647, + "minimum": -2147483648, + "type": "integer", }, + {"type": "null"}, ], - "team_members": [ - { - "id": self._1_1.id, - "manager_id": self._1.id, - "name": "1.1. First H2", - "name_length": 13, - "team_size": 0, - } - ], - "name_length": 11, - "team_size": 1, + "default": None, + "nullable": True, + "title": "Manager Id", + }, + }, + "required": ["id", "name"], + "title": "Employee", + "type": "object", + }, + }, + "additionalProperties": False, + "properties": { + "id": { + "maximum": 2147483647, + "minimum": -2147483648, + "title": "Id", + "type": "integer", + }, + "name": {"maxLength": 50, "title": "Name", "type": "string"}, + "talks_to": { + "items": {"$ref": "#/$defs/Employee_6tkbjb_leaf"}, + "title": "Talks To", + "type": "array", + }, + "manager_id": { + "anyOf": [ + {"maximum": 2147483647, "minimum": -2147483648, "type": "integer"}, + {"type": "null"}, + ], + "default": None, + "nullable": True, + "title": "Manager Id", + }, + "team_members": { + "items": {"$ref": "#/$defs/Employee_6tkbjb_leaf"}, + "title": "Team Members", + "type": "array", + }, + }, + "required": ["id", "name", "talks_to", "team_members"], + "title": "Employee", + "type": "object", + } + + +@pytest.mark.asyncio +async def test_cycle_serialisation(db, pydantic_cycle_setup): + Employee_Pydantic = pydantic_cycle_setup["Employee_Pydantic"] + root = pydantic_cycle_setup["root"] + loose = pydantic_cycle_setup["loose"] + _1 = pydantic_cycle_setup["_1"] + _2 = pydantic_cycle_setup["_2"] + _1_1 = 
pydantic_cycle_setup["_1_1"] + _1_1_1 = pydantic_cycle_setup["_1_1_1"] + _2_1 = pydantic_cycle_setup["_2_1"] + _2_2 = pydantic_cycle_setup["_2_2"] + + empp = await Employee_Pydantic.from_tortoise_orm(await Employee.get(name="Root")) + empdict = empp.model_dump() + + assert empdict == { + "id": root.id, + "manager_id": None, + "name": "Root", + "talks_to": [], + "team_members": [ + { + "id": _1.id, + "manager_id": root.id, + "name": "1. First H1", + "talks_to": [ + { + "id": loose.id, + "manager_id": None, + "name": "Loose", + "name_length": 5, + "team_size": 0, }, { - "id": self._2.id, - "manager_id": self.root.id, + "id": _2.id, + "manager_id": root.id, "name": "2. Second H1", - "talks_to": [], - "team_members": [ - { - "id": self._2_1.id, - "manager_id": self._2.id, - "name": "2.1. Second H2", - "name_length": 14, - "team_size": 0, - }, - { - "id": self._2_2.id, - "manager_id": self._2.id, - "name": "2.2. Third H2", - "name_length": 13, - "team_size": 0, - }, - ], "name_length": 12, - "team_size": 2, + "team_size": 0, + }, + { + "id": _1_1_1.id, + "manager_id": _1_1.id, + "name": "1.1.1. First H3", + "name_length": 15, + "team_size": 0, }, ], - "name_length": 4, - "team_size": 2, + "team_members": [ + { + "id": _1_1.id, + "manager_id": _1.id, + "name": "1.1. 
First H2", + "name_length": 13, + "team_size": 0, + } + ], + "name_length": 11, + "team_size": 1, }, - ) - - -class TestPydanticComputed(test.TestCase): - async def asyncSetUp(self) -> None: - await super().asyncSetUp() - self.Employee_Pydantic = pydantic_model_creator(Employee) - self.employee = await Employee.create(name="Some Employee") - self.maxDiff = None - - async def test_computed_field(self): - employee_pyd = await self.Employee_Pydantic.from_tortoise_orm( - await Employee.get(name="Some Employee") - ) - employee_serialised = employee_pyd.model_dump() - self.assertEqual(employee_serialised.get("name_length"), self.employee.name_length()) - - async def test_computed_field_schema(self): - self.assertEqual( - self.Employee_Pydantic.model_json_schema(mode="serialization"), { - "$defs": { - "Employee_fj2ly4_leaf": { - "additionalProperties": False, - "properties": { - "id": { - "maximum": 2147483647, - "minimum": -2147483648, - "title": "Id", - "type": "integer", - }, - "name": {"maxLength": 50, "title": "Name", "type": "string"}, - "manager_id": { - "anyOf": [ - { - "maximum": 2147483647, - "minimum": -2147483648, - "type": "integer", - }, - {"type": "null"}, - ], - "default": None, - "nullable": True, - "title": "Manager Id", - }, - "name_length": { - "description": "", - "readOnly": True, - "title": "Name Length", - "type": "integer", - }, - "team_size": { - "description": "Computes team size.

Note that this function needs to be annotated with a return type so that pydantic can
generate a valid schema.

Note that the pydantic serializer can't call async methods, but the tortoise helpers
pre-fetch relational data, so that it is available before serialization. So we don't
need to await the relation. We do however have to protect against the case where no
prefetching was done, hence catching and handling the
``tortoise.exceptions.NoValuesFetched`` exception.", - "readOnly": True, - "title": "Team Size", - "type": "integer", - }, - }, - "required": ["id", "name", "name_length", "team_size"], - "title": "Employee", - "type": "object", - }, - "Employee_6tkbjb_leaf": { - "additionalProperties": False, - "properties": { - "id": { + "id": _2.id, + "manager_id": root.id, + "name": "2. Second H1", + "talks_to": [], + "team_members": [ + { + "id": _2_1.id, + "manager_id": _2.id, + "name": "2.1. Second H2", + "name_length": 14, + "team_size": 0, + }, + { + "id": _2_2.id, + "manager_id": _2.id, + "name": "2.2. Third H2", + "name_length": 13, + "team_size": 0, + }, + ], + "name_length": 12, + "team_size": 2, + }, + ], + "name_length": 4, + "team_size": 2, + } + + +# Fixtures for TestPydanticComputed +@pytest_asyncio.fixture +async def pydantic_computed_setup(db): + """Setup for pydantic computed field tests.""" + Employee_Pydantic = pydantic_model_creator(Employee) + employee = await Employee.create(name="Some Employee") + + return { + "Employee_Pydantic": Employee_Pydantic, + "employee": employee, + } + + +@pytest.mark.asyncio +async def test_computed_field(db, pydantic_computed_setup): + Employee_Pydantic = pydantic_computed_setup["Employee_Pydantic"] + employee = pydantic_computed_setup["employee"] + + employee_pyd = await Employee_Pydantic.from_tortoise_orm( + await Employee.get(name="Some Employee") + ) + employee_serialised = employee_pyd.model_dump() + assert employee_serialised.get("name_length") == employee.name_length() + + +@pytest.mark.asyncio +async def test_computed_field_schema(db, pydantic_computed_setup): + Employee_Pydantic = pydantic_computed_setup["Employee_Pydantic"] + assert Employee_Pydantic.model_json_schema(mode="serialization") == { + "$defs": { + "Employee_fj2ly4_leaf": { + "additionalProperties": False, + "properties": { + "id": { + "maximum": 2147483647, + "minimum": -2147483648, + "title": "Id", + "type": "integer", + }, + "name": {"maxLength": 50, 
"title": "Name", "type": "string"}, + "manager_id": { + "anyOf": [ + { "maximum": 2147483647, "minimum": -2147483648, - "title": "Id", - "type": "integer", - }, - "name": {"maxLength": 50, "title": "Name", "type": "string"}, - "talks_to": { - "items": {"$ref": "#/$defs/Employee_fj2ly4_leaf"}, - "title": "Talks To", - "type": "array", - }, - "manager_id": { - "anyOf": [ - { - "maximum": 2147483647, - "minimum": -2147483648, - "type": "integer", - }, - {"type": "null"}, - ], - "default": None, - "nullable": True, - "title": "Manager Id", - }, - "team_members": { - "items": {"$ref": "#/$defs/Employee_fj2ly4_leaf"}, - "title": "Team Members", - "type": "array", - }, - "name_length": { - "description": "", - "readOnly": True, - "title": "Name Length", "type": "integer", }, - "team_size": { - "description": "Computes team size.

Note that this function needs to be annotated with a return type so that pydantic can
generate a valid schema.

Note that the pydantic serializer can't call async methods, but the tortoise helpers
pre-fetch relational data, so that it is available before serialization. So we don't
need to await the relation. We do however have to protect against the case where no
prefetching was done, hence catching and handling the
``tortoise.exceptions.NoValuesFetched`` exception.", - "readOnly": True, - "title": "Team Size", - "type": "integer", - }, - }, - "required": [ - "id", - "name", - "talks_to", - "team_members", - "name_length", - "team_size", + {"type": "null"}, ], - "title": "Employee", - "type": "object", + "default": None, + "nullable": True, + "title": "Manager Id", + }, + "name_length": { + "description": "", + "readOnly": True, + "title": "Name Length", + "type": "integer", + }, + "team_size": { + "description": "Computes team size.

Note that this function needs to be annotated with a return type so that pydantic can
generate a valid schema.

Note that the pydantic serializer can't call async methods, but the tortoise helpers
pre-fetch relational data, so that it is available before serialization. So we don't
need to await the relation. We do however have to protect against the case where no
prefetching was done, hence catching and handling the
``tortoise.exceptions.NoValuesFetched`` exception.", + "readOnly": True, + "title": "Team Size", + "type": "integer", }, }, + "required": ["id", "name", "name_length", "team_size"], + "title": "Employee", + "type": "object", + }, + "Employee_6tkbjb_leaf": { "additionalProperties": False, "properties": { "id": { @@ -1724,13 +1763,17 @@ async def test_computed_field_schema(self): }, "name": {"maxLength": 50, "title": "Name", "type": "string"}, "talks_to": { - "items": {"$ref": "#/$defs/Employee_6tkbjb_leaf"}, + "items": {"$ref": "#/$defs/Employee_fj2ly4_leaf"}, "title": "Talks To", "type": "array", }, "manager_id": { "anyOf": [ - {"maximum": 2147483647, "minimum": -2147483648, "type": "integer"}, + { + "maximum": 2147483647, + "minimum": -2147483648, + "type": "integer", + }, {"type": "null"}, ], "default": None, @@ -1738,7 +1781,7 @@ async def test_computed_field_schema(self): "title": "Manager Id", }, "team_members": { - "items": {"$ref": "#/$defs/Employee_6tkbjb_leaf"}, + "items": {"$ref": "#/$defs/Employee_fj2ly4_leaf"}, "title": "Team Members", "type": "array", }, @@ -1755,364 +1798,396 @@ async def test_computed_field_schema(self): "type": "integer", }, }, - "required": ["id", "name", "talks_to", "team_members", "name_length", "team_size"], + "required": [ + "id", + "name", + "talks_to", + "team_members", + "name_length", + "team_size", + ], "title": "Employee", "type": "object", }, - ) - - -class TestPydanticUpdate(test.TestCase): - def setUp(self) -> None: - self.UserCreate_Pydantic = pydantic_model_creator( - User, - name="UserCreate", - exclude_readonly=True, - ) - self.UserUpdate_Pydantic = pydantic_model_creator( - User, - name="UserUpdate", - exclude_readonly=True, - optional=("username", "mail", "bio"), - ) - - def test_create_schema(self): - self.assertEqual( - self.UserCreate_Pydantic.model_json_schema(), - { - "title": "UserCreate", - "type": "object", - "properties": { - "username": { - "title": "Username", - "maxLength": 32, - "type": "string", - 
}, - "mail": { - "title": "Mail", - "maxLength": 64, - "type": "string", - }, - "bio": { - "title": "Bio", - "type": "string", - }, - }, - "required": [ - "username", - "mail", - "bio", + }, + "additionalProperties": False, + "properties": { + "id": { + "maximum": 2147483647, + "minimum": -2147483648, + "title": "Id", + "type": "integer", + }, + "name": {"maxLength": 50, "title": "Name", "type": "string"}, + "talks_to": { + "items": {"$ref": "#/$defs/Employee_6tkbjb_leaf"}, + "title": "Talks To", + "type": "array", + }, + "manager_id": { + "anyOf": [ + {"maximum": 2147483647, "minimum": -2147483648, "type": "integer"}, + {"type": "null"}, ], - "additionalProperties": False, + "default": None, + "nullable": True, + "title": "Manager Id", }, - ) - - def test_update_schema(self): - """All fields of this schema should be optional. - This demonstrates an example PATCH endpoint in an API, where a client may want - to update a single field of a model without modifying the rest. - """ - self.assertEqual( - self.UserUpdate_Pydantic.model_json_schema(), - { - "additionalProperties": False, - "properties": { - "bio": { - "anyOf": [{"type": "string"}, {"type": "null"}], - "default": None, - "title": "Bio", - }, - "mail": { - "anyOf": [{"maxLength": 64, "type": "string"}, {"type": "null"}], - "default": None, - "title": "Mail", - }, - "username": { - "anyOf": [{"maxLength": 32, "type": "string"}, {"type": "null"}], - "default": None, - "title": "Username", - }, - }, - "title": "UserUpdate", - "type": "object", + "team_members": { + "items": {"$ref": "#/$defs/Employee_6tkbjb_leaf"}, + "title": "Team Members", + "type": "array", }, - ) - - -class TestPydanticOptionalUpdate(test.TestCase): - def setUp(self) -> None: - self.UserUpdateAllOptional_Pydantic = pydantic_model_creator( - User, - name="UserUpdateAllOptional", - exclude_readonly=True, - optional=("username", "mail", "bio"), - ) - self.UserUpdatePartialOptional_Pydantic = pydantic_model_creator( - User, - 
name="UserUpdatePartialOptional", - exclude_readonly=True, - optional=("username", "mail"), - ) - self.UserUpdateWithoutOptional_Pydantic = pydantic_model_creator( - User, - name="UserUpdateWithoutOptional", - exclude_readonly=True, - ) - - def test_optional_update(self): - # All fields are optional - self.assertEqual(self.UserUpdateAllOptional_Pydantic().model_dump(exclude_unset=True), {}) - self.assertEqual( - self.UserUpdateAllOptional_Pydantic(bio="foo").model_dump(exclude_unset=True), - {"bio": "foo"}, - ) - self.assertEqual( - self.UserUpdateAllOptional_Pydantic(username="name", mail="a@example.com").model_dump( - exclude_unset=True - ), - {"username": "name", "mail": "a@example.com"}, - ) - self.assertEqual( - self.UserUpdateAllOptional_Pydantic(username="name", mail="a@example.com").model_dump(), - {"username": "name", "mail": "a@example.com", "bio": None}, - ) - # Some fields are optional - with pytest.raises(ValidationError): - self.UserUpdatePartialOptional_Pydantic() - with pytest.raises(ValidationError): - self.UserUpdatePartialOptional_Pydantic(username="name") - self.assertEqual( - self.UserUpdatePartialOptional_Pydantic(bio="foo").model_dump(exclude_unset=True), - {"bio": "foo"}, - ) - self.assertEqual( - self.UserUpdatePartialOptional_Pydantic( - username="name", mail="a@example.com", bio="" - ).model_dump(exclude_unset=True), - {"username": "name", "mail": "a@example.com", "bio": ""}, - ) - self.assertEqual( - self.UserUpdatePartialOptional_Pydantic(mail="a@example.com", bio="").model_dump(), - {"username": None, "mail": "a@example.com", "bio": ""}, - ) - # None of the fields is optional - with pytest.raises(ValidationError): - self.UserUpdateWithoutOptional_Pydantic() - with pytest.raises(ValidationError): - self.UserUpdateWithoutOptional_Pydantic(username="name") - with pytest.raises(ValidationError): - self.UserUpdateWithoutOptional_Pydantic(username="name", email="") - self.assertEqual( - self.UserUpdateWithoutOptional_Pydantic( - 
username="name", mail="a@example.com", bio="" - ).model_dump(), - {"username": "name", "mail": "a@example.com", "bio": ""}, - ) - - -class TestPydanticMutlipleModelUses(test.TestCase): - def setUp(self) -> None: - self.NoRelationsModel = IntFields - self.ModelWithRelations = Event - - def test_no_relations_model_reused(self): - Pydantic1 = pydantic_model_creator(self.NoRelationsModel) - Pydantic2 = pydantic_model_creator(self.NoRelationsModel) - - self.assertIs(Pydantic1, Pydantic2) - - def test_no_relations_model_one_exclude(self): - Pydantic1 = pydantic_model_creator(self.NoRelationsModel) - Pydantic2 = pydantic_model_creator(self.NoRelationsModel, exclude=("id",)) - - self.assertIsNot(Pydantic1, Pydantic2) - self.assertIn("id", Pydantic1.model_json_schema()["required"]) - self.assertNotIn("id", Pydantic2.model_json_schema()["required"]) - - def test_no_relations_model_both_exclude(self): - Pydantic1 = pydantic_model_creator(self.NoRelationsModel, exclude=("id",)) - Pydantic2 = pydantic_model_creator(self.NoRelationsModel, exclude=("id",)) - - self.assertIs(Pydantic1, Pydantic2) - self.assertNotIn("id", Pydantic1.model_json_schema()["required"]) - self.assertNotIn("id", Pydantic2.model_json_schema()["required"]) - - def test_no_relations_model_exclude_diff(self): - Pydantic1 = pydantic_model_creator(self.NoRelationsModel, exclude=("id",)) - Pydantic2 = pydantic_model_creator(self.NoRelationsModel, exclude=("name",)) - - self.assertIsNot(Pydantic1, Pydantic2) - - def test_no_relations_model_exclude_readonly(self): - Pydantic1 = pydantic_model_creator(self.NoRelationsModel) - Pydantic2 = pydantic_model_creator(self.NoRelationsModel, exclude_readonly=True) - - self.assertIsNot(Pydantic1, Pydantic2) - self.assertIn("id", Pydantic1.model_json_schema()["properties"]) - self.assertNotIn("id", Pydantic2.model_json_schema()["properties"]) - - def test_model_with_relations_reused(self): - Pydantic1 = pydantic_model_creator(self.ModelWithRelations) - Pydantic2 = 
pydantic_model_creator(self.ModelWithRelations) - - self.assertIs(Pydantic1, Pydantic2) - - def test_model_with_relations_exclude(self): - Pydantic1 = pydantic_model_creator(self.ModelWithRelations) - Pydantic2 = pydantic_model_creator(self.ModelWithRelations, exclude=("event_id",)) - - self.assertIsNot(Pydantic1, Pydantic2) - self.assertIn("event_id", Pydantic1.model_json_schema()["properties"]) - self.assertNotIn("event_id", Pydantic2.model_json_schema()["properties"]) - - def test_model_with_relations_exclude_readonly(self): - Pydantic1 = pydantic_model_creator(self.ModelWithRelations) - Pydantic2 = pydantic_model_creator(self.ModelWithRelations, exclude_readonly=True) - - self.assertIsNot(Pydantic1, Pydantic2) - self.assertIn("event_id", Pydantic1.model_json_schema()["properties"]) - self.assertNotIn("event_id", Pydantic2.model_json_schema()["properties"]) - - def test_named_no_relations_model(self): - Pydantic1 = pydantic_model_creator(self.NoRelationsModel, name="Foo") - Pydantic2 = pydantic_model_creator(self.NoRelationsModel, name="Foo") - - self.assertIs(Pydantic1, Pydantic2) - - def test_named_model_with_relations(self): - Pydantic1 = pydantic_model_creator(self.ModelWithRelations, name="Foo") - Pydantic2 = pydantic_model_creator(self.ModelWithRelations, name="Foo") + "name_length": { + "description": "", + "readOnly": True, + "title": "Name Length", + "type": "integer", + }, + "team_size": { + "description": "Computes team size.

Note that this function needs to be annotated with a return type so that pydantic can
generate a valid schema.

Note that the pydantic serializer can't call async methods, but the tortoise helpers
pre-fetch relational data, so that it is available before serialization. So we don't
need to await the relation. We do however have to protect against the case where no
prefetching was done, hence catching and handling the
``tortoise.exceptions.NoValuesFetched`` exception.", + "readOnly": True, + "title": "Team Size", + "type": "integer", + }, + }, + "required": ["id", "name", "talks_to", "team_members", "name_length", "team_size"], + "title": "Employee", + "type": "object", + } + + +# Tests for TestPydanticUpdate +def test_create_schema(db): + UserCreate_Pydantic = pydantic_model_creator( + User, + name="UserCreate", + exclude_readonly=True, + ) + assert UserCreate_Pydantic.model_json_schema() == { + "title": "UserCreate", + "type": "object", + "properties": { + "username": { + "title": "Username", + "maxLength": 32, + "type": "string", + }, + "mail": { + "title": "Mail", + "maxLength": 64, + "type": "string", + }, + "bio": { + "title": "Bio", + "type": "string", + }, + }, + "required": [ + "username", + "mail", + "bio", + ], + "additionalProperties": False, + } + + +def test_update_schema(db): + """All fields of this schema should be optional. + This demonstrates an example PATCH endpoint in an API, where a client may want + to update a single field of a model without modifying the rest. 
+ """ + UserUpdate_Pydantic = pydantic_model_creator( + User, + name="UserUpdate", + exclude_readonly=True, + optional=("username", "mail", "bio"), + ) + assert UserUpdate_Pydantic.model_json_schema() == { + "additionalProperties": False, + "properties": { + "bio": { + "anyOf": [{"type": "string"}, {"type": "null"}], + "default": None, + "title": "Bio", + }, + "mail": { + "anyOf": [{"maxLength": 64, "type": "string"}, {"type": "null"}], + "default": None, + "title": "Mail", + }, + "username": { + "anyOf": [{"maxLength": 32, "type": "string"}, {"type": "null"}], + "default": None, + "title": "Username", + }, + }, + "title": "UserUpdate", + "type": "object", + } + + +# Tests for TestPydanticOptionalUpdate +def test_optional_update(db): + UserUpdateAllOptional_Pydantic = pydantic_model_creator( + User, + name="UserUpdateAllOptional", + exclude_readonly=True, + optional=("username", "mail", "bio"), + ) + UserUpdatePartialOptional_Pydantic = pydantic_model_creator( + User, + name="UserUpdatePartialOptional", + exclude_readonly=True, + optional=("username", "mail"), + ) + UserUpdateWithoutOptional_Pydantic = pydantic_model_creator( + User, + name="UserUpdateWithoutOptional", + exclude_readonly=True, + ) + + # All fields are optional + assert UserUpdateAllOptional_Pydantic().model_dump(exclude_unset=True) == {} + assert UserUpdateAllOptional_Pydantic(bio="foo").model_dump(exclude_unset=True) == { + "bio": "foo" + } + assert UserUpdateAllOptional_Pydantic(username="name", mail="a@example.com").model_dump( + exclude_unset=True + ) == {"username": "name", "mail": "a@example.com"} + assert UserUpdateAllOptional_Pydantic(username="name", mail="a@example.com").model_dump() == { + "username": "name", + "mail": "a@example.com", + "bio": None, + } + # Some fields are optional + with pytest.raises(ValidationError): + UserUpdatePartialOptional_Pydantic() + with pytest.raises(ValidationError): + UserUpdatePartialOptional_Pydantic(username="name") + assert 
UserUpdatePartialOptional_Pydantic(bio="foo").model_dump(exclude_unset=True) == { + "bio": "foo" + } + assert UserUpdatePartialOptional_Pydantic( + username="name", mail="a@example.com", bio="" + ).model_dump(exclude_unset=True) == {"username": "name", "mail": "a@example.com", "bio": ""} + assert UserUpdatePartialOptional_Pydantic(mail="a@example.com", bio="").model_dump() == { + "username": None, + "mail": "a@example.com", + "bio": "", + } + # None of the fields is optional + with pytest.raises(ValidationError): + UserUpdateWithoutOptional_Pydantic() + with pytest.raises(ValidationError): + UserUpdateWithoutOptional_Pydantic(username="name") + with pytest.raises(ValidationError): + UserUpdateWithoutOptional_Pydantic(username="name", email="") + assert UserUpdateWithoutOptional_Pydantic( + username="name", mail="a@example.com", bio="" + ).model_dump() == {"username": "name", "mail": "a@example.com", "bio": ""} + + +# Tests for TestPydanticMutlipleModelUses +def test_no_relations_model_reused(db): + NoRelationsModel = IntFields + Pydantic1 = pydantic_model_creator(NoRelationsModel) + Pydantic2 = pydantic_model_creator(NoRelationsModel) + + assert Pydantic1 is Pydantic2 + + +def test_no_relations_model_one_exclude(db): + NoRelationsModel = IntFields + Pydantic1 = pydantic_model_creator(NoRelationsModel) + Pydantic2 = pydantic_model_creator(NoRelationsModel, exclude=("id",)) + + assert Pydantic1 is not Pydantic2 + assert "id" in Pydantic1.model_json_schema()["required"] + assert "id" not in Pydantic2.model_json_schema()["required"] + + +def test_no_relations_model_both_exclude(db): + NoRelationsModel = IntFields + Pydantic1 = pydantic_model_creator(NoRelationsModel, exclude=("id",)) + Pydantic2 = pydantic_model_creator(NoRelationsModel, exclude=("id",)) + + assert Pydantic1 is Pydantic2 + assert "id" not in Pydantic1.model_json_schema()["required"] + assert "id" not in Pydantic2.model_json_schema()["required"] + + +def test_no_relations_model_exclude_diff(db): + 
NoRelationsModel = IntFields + Pydantic1 = pydantic_model_creator(NoRelationsModel, exclude=("id",)) + Pydantic2 = pydantic_model_creator(NoRelationsModel, exclude=("name",)) + + assert Pydantic1 is not Pydantic2 + + +def test_no_relations_model_exclude_readonly(db): + NoRelationsModel = IntFields + Pydantic1 = pydantic_model_creator(NoRelationsModel) + Pydantic2 = pydantic_model_creator(NoRelationsModel, exclude_readonly=True) + + assert Pydantic1 is not Pydantic2 + assert "id" in Pydantic1.model_json_schema()["properties"] + assert "id" not in Pydantic2.model_json_schema()["properties"] + + +def test_model_with_relations_reused(db): + ModelWithRelations = Event + Pydantic1 = pydantic_model_creator(ModelWithRelations) + Pydantic2 = pydantic_model_creator(ModelWithRelations) + + assert Pydantic1 is Pydantic2 + + +def test_model_with_relations_exclude(db): + ModelWithRelations = Event + Pydantic1 = pydantic_model_creator(ModelWithRelations) + Pydantic2 = pydantic_model_creator(ModelWithRelations, exclude=("event_id",)) + + assert Pydantic1 is not Pydantic2 + assert "event_id" in Pydantic1.model_json_schema()["properties"] + assert "event_id" not in Pydantic2.model_json_schema()["properties"] + + +def test_model_with_relations_exclude_readonly(db): + ModelWithRelations = Event + Pydantic1 = pydantic_model_creator(ModelWithRelations) + Pydantic2 = pydantic_model_creator(ModelWithRelations, exclude_readonly=True) + + assert Pydantic1 is not Pydantic2 + assert "event_id" in Pydantic1.model_json_schema()["properties"] + assert "event_id" not in Pydantic2.model_json_schema()["properties"] + + +def test_named_no_relations_model(db): + NoRelationsModel = IntFields + Pydantic1 = pydantic_model_creator(NoRelationsModel, name="Foo") + Pydantic2 = pydantic_model_creator(NoRelationsModel, name="Foo") - self.assertIs(Pydantic1, Pydantic2) + assert Pydantic1 is Pydantic2 -class TestPydanticEnum(test.TestCase): - def setUp(self) -> None: - self.EnumFields_Pydantic = 
pydantic_model_creator(EnumFields) +def test_named_model_with_relations(db): + ModelWithRelations = Event + Pydantic1 = pydantic_model_creator(ModelWithRelations, name="Foo") + Pydantic2 = pydantic_model_creator(ModelWithRelations, name="Foo") - def test_int_enum(self): - with self.assertRaises(ValidationError) as cm: - self.EnumFields_Pydantic.model_validate({"id": 1, "service": 4, "currency": "HUF"}) - self.assertEqual( - [ - { - "type": "enum", - "loc": ("service",), - "msg": "Input should be 1, 2 or 3", - "input": 4, - "ctx": {"expected": "1, 2 or 3"}, - } - ], - cm.exception.errors(include_url=False), - ) - with self.assertRaises(ValidationError) as cm: - self.EnumFields_Pydantic.model_validate( - {"id": 1, "service": "a string, not int", "currency": "HUF"} - ) - self.assertEqual( - [ - { - "type": "enum", - "loc": ("service",), - "msg": "Input should be 1, 2 or 3", - "input": "a string, not int", - "ctx": {"expected": "1, 2 or 3"}, - } - ], - cm.exception.errors(include_url=False), - ) + assert Pydantic1 is Pydantic2 - def test_str_enum(self): - with self.assertRaises(ValidationError) as cm: - self.EnumFields_Pydantic.model_validate( - {"id": 1, "service": 3, "currency": "GoofyGooberDollar"} - ) - self.assertEqual( - [ - { - "type": "enum", - "loc": ("currency",), - "msg": "Input should be 'HUF', 'EUR' or 'USD'", - "input": "GoofyGooberDollar", - "ctx": {"expected": "'HUF', 'EUR' or 'USD'"}, - } - ], - cm.exception.errors(include_url=False), - ) - with self.assertRaises(ValidationError) as cm: - self.EnumFields_Pydantic.model_validate({"id": 1, "service": 3, "currency": 1}) - self.assertEqual( - [ - { - "type": "enum", - "loc": ("currency",), - "msg": "Input should be 'HUF', 'EUR' or 'USD'", - "input": 1, - "ctx": {"expected": "'HUF', 'EUR' or 'USD'"}, - } - ], - cm.exception.errors(include_url=False), - ) - def test_enum(self): - with self.assertRaises(ValidationError) as cm: - self.EnumFields_Pydantic.model_validate({"id": 1, "service": 4, "currency": 1}) - 
self.assertEqual( - [ - { - "type": "enum", - "loc": ("service",), - "msg": "Input should be 1, 2 or 3", - "input": 4, - "ctx": {"expected": "1, 2 or 3"}, - }, - { - "type": "enum", - "loc": ("currency",), - "msg": "Input should be 'HUF', 'EUR' or 'USD'", - "input": 1, - "ctx": {"expected": "'HUF', 'EUR' or 'USD'"}, - }, - ], - cm.exception.errors(include_url=False), +# Tests for TestPydanticEnum +def test_int_enum(db): + EnumFields_Pydantic = pydantic_model_creator(EnumFields) + with pytest.raises(ValidationError) as exc_info: + EnumFields_Pydantic.model_validate({"id": 1, "service": 4, "currency": "HUF"}) + assert [ + { + "type": "enum", + "loc": ("service",), + "msg": "Input should be 1, 2 or 3", + "input": 4, + "ctx": {"expected": "1, 2 or 3"}, + } + ] == exc_info.value.errors(include_url=False) + with pytest.raises(ValidationError) as exc_info: + EnumFields_Pydantic.model_validate( + {"id": 1, "service": "a string, not int", "currency": "HUF"} ) - - # should simply not raise any error: - self.EnumFields_Pydantic.model_validate({"id": 1, "service": 3, "currency": "HUF"}) - self.assertEqual( - { - "$defs": { - "Currency": { - "enum": ["HUF", "EUR", "USD"], - "title": "Currency", - "type": "string", - }, - "Service": {"enum": [1, 2, 3], "title": "Service", "type": "integer"}, - }, - "additionalProperties": False, - "properties": { - "id": { - "maximum": 2147483647, - "minimum": -2147483648, - "title": "Id", - "type": "integer", - }, - "service": { - "$ref": "#/$defs/Service", - "description": "python_programming: 1
database_design: 2
system_administration: 3", - "ge": -32768, - "le": 32767, - }, - "currency": { - "$ref": "#/$defs/Currency", - "default": "HUF", - "description": "HUF: HUF
EUR: EUR
USD: USD", - "maxLength": 3, - }, - }, - "required": ["id", "service"], - "title": "EnumFields", - "type": "object", + assert [ + { + "type": "enum", + "loc": ("service",), + "msg": "Input should be 1, 2 or 3", + "input": "a string, not int", + "ctx": {"expected": "1, 2 or 3"}, + } + ] == exc_info.value.errors(include_url=False) + + +def test_str_enum(db): + EnumFields_Pydantic = pydantic_model_creator(EnumFields) + with pytest.raises(ValidationError) as exc_info: + EnumFields_Pydantic.model_validate({"id": 1, "service": 3, "currency": "GoofyGooberDollar"}) + assert [ + { + "type": "enum", + "loc": ("currency",), + "msg": "Input should be 'HUF', 'EUR' or 'USD'", + "input": "GoofyGooberDollar", + "ctx": {"expected": "'HUF', 'EUR' or 'USD'"}, + } + ] == exc_info.value.errors(include_url=False) + with pytest.raises(ValidationError) as exc_info: + EnumFields_Pydantic.model_validate({"id": 1, "service": 3, "currency": 1}) + assert [ + { + "type": "enum", + "loc": ("currency",), + "msg": "Input should be 'HUF', 'EUR' or 'USD'", + "input": 1, + "ctx": {"expected": "'HUF', 'EUR' or 'USD'"}, + } + ] == exc_info.value.errors(include_url=False) + + +def test_enum(db): + EnumFields_Pydantic = pydantic_model_creator(EnumFields) + with pytest.raises(ValidationError) as exc_info: + EnumFields_Pydantic.model_validate({"id": 1, "service": 4, "currency": 1}) + assert [ + { + "type": "enum", + "loc": ("service",), + "msg": "Input should be 1, 2 or 3", + "input": 4, + "ctx": {"expected": "1, 2 or 3"}, + }, + { + "type": "enum", + "loc": ("currency",), + "msg": "Input should be 'HUF', 'EUR' or 'USD'", + "input": 1, + "ctx": {"expected": "'HUF', 'EUR' or 'USD'"}, + }, + ] == exc_info.value.errors(include_url=False) + + # should simply not raise any error: + EnumFields_Pydantic.model_validate({"id": 1, "service": 3, "currency": "HUF"}) + assert { + "$defs": { + "Currency": { + "enum": ["HUF", "EUR", "USD"], + "title": "Currency", + "type": "string", }, - 
self.EnumFields_Pydantic.model_json_schema(), - ) + "Service": {"enum": [1, 2, 3], "title": "Service", "type": "integer"}, + }, + "additionalProperties": False, + "properties": { + "id": { + "maximum": 2147483647, + "minimum": -2147483648, + "title": "Id", + "type": "integer", + }, + "service": { + "$ref": "#/$defs/Service", + "description": "python_programming: 1
database_design: 2
system_administration: 3", + "ge": -32768, + "le": 32767, + }, + "currency": { + "$ref": "#/$defs/Currency", + "default": "HUF", + "description": "HUF: HUF
EUR: EUR
USD: USD", + "maxLength": 3, + }, + }, + "required": ["id", "service"], + "title": "EnumFields", + "type": "object", + } == EnumFields_Pydantic.model_json_schema() diff --git a/tests/contrib/test_tester.py b/tests/contrib/test_tester.py deleted file mode 100644 index 39538700f..000000000 --- a/tests/contrib/test_tester.py +++ /dev/null @@ -1,41 +0,0 @@ -# pylint: disable=W1503 -from tortoise.contrib import test - - -class TestTesterSync(test.SimpleTestCase): - def setUp(self): - self.moo = "SET" - - def tearDown(self): - self.assertEqual(self.moo, "SET") - - @test.skip("Skip it") - def test_skip(self): - self.assertTrue(False) - - @test.expectedFailure - def test_fail(self): - self.assertTrue(False) - - def test_moo(self): - self.assertEqual(self.moo, "SET") - - -class TestTesterASync(test.SimpleTestCase): - async def asyncSetUp(self): - await super().asyncSetUp() - self.baa = "TES" - - def tearDown(self): - self.assertEqual(self.baa, "TES") - - @test.skip("Skip it") - async def test_skip(self): - self.assertTrue(False) - - @test.expectedFailure - async def test_fail(self): - self.assertTrue(False) - - async def test_moo(self): - self.assertEqual(self.baa, "TES") diff --git a/tests/fields/conftest.py b/tests/fields/conftest.py new file mode 100644 index 000000000..90ad78ddd --- /dev/null +++ b/tests/fields/conftest.py @@ -0,0 +1,52 @@ +""" +Custom fixtures for field tests that require specific model modules. + +These fixtures support tests that define tortoise_test_modules to use +custom model definitions instead of the default tests.testmodels. +""" + +import os + +import pytest +import pytest_asyncio + +from tortoise.context import tortoise_test_context + + +@pytest_asyncio.fixture(scope="function") +async def db_array_fields(): + """ + Fixture for TestArrayFields. + + Uses models defined in tests.testmodels_postgres module. 
+ Equivalent to: test.IsolatedTestCase with tortoise_test_modules=["tests.testmodels_postgres"] + """ + db_url = os.getenv("TORTOISE_TEST_DB", "sqlite://:memory:") + # Skip on non-Postgres databases since ArrayFields require Postgres + if "postgres" not in db_url: + pytest.skip("ArrayFields require PostgreSQL") + async with tortoise_test_context( + modules=["tests.testmodels_postgres"], + db_url=db_url, + app_label="models", + connection_label="models", + ) as ctx: + yield ctx + + +@pytest_asyncio.fixture(scope="function") +async def db_subclass_fields(): + """ + Fixture for TestEnumField and TestCustomFieldFilters. + + Uses models defined in tests.fields.subclass_models module. + Equivalent to: test.IsolatedTestCase with tortoise_test_modules=["tests.fields.subclass_models"] + """ + db_url = os.getenv("TORTOISE_TEST_DB", "sqlite://:memory:") + async with tortoise_test_context( + modules=["tests.fields.subclass_models"], + db_url=db_url, + app_label="models", + connection_label="models", + ) as ctx: + yield ctx diff --git a/tests/fields/test_array.py b/tests/fields/test_array.py index 2bf192dd2..8b3e78d11 100644 --- a/tests/fields/test_array.py +++ b/tests/fields/test_array.py @@ -1,146 +1,193 @@ +import pytest + from tests import testmodels_postgres as testmodels -from tortoise.contrib import test -from tortoise.exceptions import IntegrityError, OperationalError +from tortoise.contrib.test import requireCapability +from tortoise.exceptions import IntegrityError + + +@requireCapability(dialect="postgres") +@pytest.mark.asyncio +async def test_empty(db_array_fields): + """Test that creating without required array field raises IntegrityError.""" + with pytest.raises(IntegrityError): + await testmodels.ArrayFields.create() + + +@requireCapability(dialect="postgres") +@pytest.mark.asyncio +async def test_create(db_array_fields): + """Test array field creation and retrieval.""" + obj0 = await testmodels.ArrayFields.create(array=[0]) + obj = await 
testmodels.ArrayFields.get(id=obj0.id) + assert obj.array == [0] + assert obj.array_null is None + await obj.save() + obj2 = await testmodels.ArrayFields.get(id=obj.id) + assert obj == obj2 + + +@requireCapability(dialect="postgres") +@pytest.mark.asyncio +async def test_update(db_array_fields): + """Test array field update.""" + obj0 = await testmodels.ArrayFields.create(array=[0]) + await testmodels.ArrayFields.filter(id=obj0.id).update(array=[1]) + obj = await testmodels.ArrayFields.get(id=obj0.id) + assert obj.array == [1] + assert obj.array_null is None + + +@requireCapability(dialect="postgres") +@pytest.mark.asyncio +async def test_values(db_array_fields): + """Test array field in values().""" + obj0 = await testmodels.ArrayFields.create(array=[0]) + values = await testmodels.ArrayFields.get(id=obj0.id).values("array") + assert values["array"] == [0] + + +@requireCapability(dialect="postgres") +@pytest.mark.asyncio +async def test_values_list(db_array_fields): + """Test array field in values_list().""" + obj0 = await testmodels.ArrayFields.create(array=[0]) + values = await testmodels.ArrayFields.get(id=obj0.id).values_list("array", flat=True) + assert values == [0] + +@requireCapability(dialect="postgres") +@pytest.mark.asyncio +async def test_eq_filter(db_array_fields): + """Test equality filter on array field.""" + obj1 = await testmodels.ArrayFields.create(array=[1, 2, 3]) + obj2 = await testmodels.ArrayFields.create(array=[1, 2]) -@test.requireCapability(dialect="postgres") -class TestArrayFields(test.IsolatedTestCase): - tortoise_test_modules = ["tests.testmodels_postgres"] + found = await testmodels.ArrayFields.filter(array=[1, 2, 3]).first() + assert found == obj1 - async def _setUpDB(self) -> None: - try: - await super()._setUpDB() - except OperationalError: - raise test.SkipTest("Works only with PostgreSQL") + found = await testmodels.ArrayFields.filter(array=[1, 2]).first() + assert found == obj2 - async def test_empty(self): - with 
self.assertRaises(IntegrityError): - await testmodels.ArrayFields.create() - async def test_create(self): - obj0 = await testmodels.ArrayFields.create(array=[0]) - obj = await testmodels.ArrayFields.get(id=obj0.id) - self.assertEqual(obj.array, [0]) - self.assertIs(obj.array_null, None) - await obj.save() - obj2 = await testmodels.ArrayFields.get(id=obj.id) - self.assertEqual(obj, obj2) +@requireCapability(dialect="postgres") +@pytest.mark.asyncio +async def test_not_filter(db_array_fields): + """Test not filter on array field.""" + await testmodels.ArrayFields.create(array=[1, 2, 3]) + obj2 = await testmodels.ArrayFields.create(array=[1, 2]) - async def test_update(self): - obj0 = await testmodels.ArrayFields.create(array=[0]) - await testmodels.ArrayFields.filter(id=obj0.id).update(array=[1]) - obj = await testmodels.ArrayFields.get(id=obj0.id) - self.assertEqual(obj.array, [1]) - self.assertIs(obj.array_null, None) + found = await testmodels.ArrayFields.filter(array__not=[1, 2, 3]).first() + assert found == obj2 - async def test_values(self): - obj0 = await testmodels.ArrayFields.create(array=[0]) - values = await testmodels.ArrayFields.get(id=obj0.id).values("array") - self.assertEqual(values["array"], [0]) - async def test_values_list(self): - obj0 = await testmodels.ArrayFields.create(array=[0]) - values = await testmodels.ArrayFields.get(id=obj0.id).values_list("array", flat=True) - self.assertEqual(values, [0]) +@requireCapability(dialect="postgres") +@pytest.mark.asyncio +async def test_contains_ints(db_array_fields): + """Test contains filter on integer array field.""" + obj1 = await testmodels.ArrayFields.create(array=[1, 2, 3]) + obj2 = await testmodels.ArrayFields.create(array=[2, 3]) + await testmodels.ArrayFields.create(array=[4, 5, 6]) - async def test_eq_filter(self): - obj1 = await testmodels.ArrayFields.create(array=[1, 2, 3]) - obj2 = await testmodels.ArrayFields.create(array=[1, 2]) + found = await 
testmodels.ArrayFields.filter(array__contains=[2]) + assert found == [obj1, obj2] - found = await testmodels.ArrayFields.filter(array=[1, 2, 3]).first() - self.assertEqual(found, obj1) + found = await testmodels.ArrayFields.filter(array__contains=[10]) + assert found == [] - found = await testmodels.ArrayFields.filter(array=[1, 2]).first() - self.assertEqual(found, obj2) - async def test_not_filter(self): - await testmodels.ArrayFields.create(array=[1, 2, 3]) - obj2 = await testmodels.ArrayFields.create(array=[1, 2]) +@requireCapability(dialect="postgres") +@pytest.mark.asyncio +async def test_contains_smallints(db_array_fields): + """Test contains filter on smallint array field.""" + obj1 = await testmodels.ArrayFields.create(array=[], array_smallint=[1, 2, 3]) - found = await testmodels.ArrayFields.filter(array__not=[1, 2, 3]).first() - self.assertEqual(found, obj2) + found = await testmodels.ArrayFields.filter(array_smallint__contains=[2]).first() + assert found == obj1 - async def test_contains_ints(self): - obj1 = await testmodels.ArrayFields.create(array=[1, 2, 3]) - obj2 = await testmodels.ArrayFields.create(array=[2, 3]) - await testmodels.ArrayFields.create(array=[4, 5, 6]) - found = await testmodels.ArrayFields.filter(array__contains=[2]) - self.assertEqual(found, [obj1, obj2]) +@requireCapability(dialect="postgres") +@pytest.mark.asyncio +async def test_contains_strs(db_array_fields): + """Test contains filter on string array field.""" + obj1 = await testmodels.ArrayFields.create(array_str=["a", "b", "c"], array=[]) - found = await testmodels.ArrayFields.filter(array__contains=[10]) - self.assertEqual(found, []) + found = await testmodels.ArrayFields.filter(array_str__contains=["a", "b", "c"]) + assert found == [obj1] - async def test_contains_smallints(self): - obj1 = await testmodels.ArrayFields.create(array=[], array_smallint=[1, 2, 3]) + found = await testmodels.ArrayFields.filter(array_str__contains=["a", "b"]) + assert found == [obj1] - found = 
await testmodels.ArrayFields.filter(array_smallint__contains=[2]).first() - self.assertEqual(found, obj1) + found = await testmodels.ArrayFields.filter(array_str__contains=["a", "b", "c", "d"]) + assert found == [] - async def test_contains_strs(self): - obj1 = await testmodels.ArrayFields.create(array_str=["a", "b", "c"], array=[]) - found = await testmodels.ArrayFields.filter(array_str__contains=["a", "b", "c"]) - self.assertEqual(found, [obj1]) +@requireCapability(dialect="postgres") +@pytest.mark.asyncio +async def test_contained_by_ints(db_array_fields): + """Test contained_by filter on integer array field.""" + obj1 = await testmodels.ArrayFields.create(array=[1]) + obj2 = await testmodels.ArrayFields.create(array=[1, 2]) + obj3 = await testmodels.ArrayFields.create(array=[1, 2, 3]) - found = await testmodels.ArrayFields.filter(array_str__contains=["a", "b"]) - self.assertEqual(found, [obj1]) + found = await testmodels.ArrayFields.filter(array__contained_by=[1, 2, 3]) + assert found == [obj1, obj2, obj3] - found = await testmodels.ArrayFields.filter(array_str__contains=["a", "b", "c", "d"]) - self.assertEqual(found, []) + found = await testmodels.ArrayFields.filter(array__contained_by=[1, 2]) + assert found == [obj1, obj2] - async def test_contained_by_ints(self): - obj1 = await testmodels.ArrayFields.create(array=[1]) - obj2 = await testmodels.ArrayFields.create(array=[1, 2]) - obj3 = await testmodels.ArrayFields.create(array=[1, 2, 3]) + found = await testmodels.ArrayFields.filter(array__contained_by=[1]) + assert found == [obj1] - found = await testmodels.ArrayFields.filter(array__contained_by=[1, 2, 3]) - self.assertEqual(found, [obj1, obj2, obj3]) - found = await testmodels.ArrayFields.filter(array__contained_by=[1, 2]) - self.assertEqual(found, [obj1, obj2]) +@requireCapability(dialect="postgres") +@pytest.mark.asyncio +async def test_contained_by_strs(db_array_fields): + """Test contained_by filter on string array field.""" + obj1 = await 
testmodels.ArrayFields.create(array_str=["a"], array=[]) + obj2 = await testmodels.ArrayFields.create(array_str=["a", "b"], array=[]) + obj3 = await testmodels.ArrayFields.create(array_str=["a", "b", "c"], array=[]) - found = await testmodels.ArrayFields.filter(array__contained_by=[1]) - self.assertEqual(found, [obj1]) + found = await testmodels.ArrayFields.filter(array_str__contained_by=["a", "b", "c", "d"]) + assert found == [obj1, obj2, obj3] - async def test_contained_by_strs(self): - obj1 = await testmodels.ArrayFields.create(array_str=["a"], array=[]) - obj2 = await testmodels.ArrayFields.create(array_str=["a", "b"], array=[]) - obj3 = await testmodels.ArrayFields.create(array_str=["a", "b", "c"], array=[]) + found = await testmodels.ArrayFields.filter(array_str__contained_by=["a", "b"]) + assert found == [obj1, obj2] - found = await testmodels.ArrayFields.filter(array_str__contained_by=["a", "b", "c", "d"]) - self.assertEqual(found, [obj1, obj2, obj3]) + found = await testmodels.ArrayFields.filter(array_str__contained_by=["x", "y", "z"]) + assert found == [] - found = await testmodels.ArrayFields.filter(array_str__contained_by=["a", "b"]) - self.assertEqual(found, [obj1, obj2]) - found = await testmodels.ArrayFields.filter(array_str__contained_by=["x", "y", "z"]) - self.assertEqual(found, []) +@requireCapability(dialect="postgres") +@pytest.mark.asyncio +async def test_overlap_ints(db_array_fields): + """Test overlap filter on integer array field.""" + obj1 = await testmodels.ArrayFields.create(array=[1, 2, 3]) + obj2 = await testmodels.ArrayFields.create(array=[2, 3, 4]) + obj3 = await testmodels.ArrayFields.create(array=[3, 4, 5]) - async def test_overlap_ints(self): - obj1 = await testmodels.ArrayFields.create(array=[1, 2, 3]) - obj2 = await testmodels.ArrayFields.create(array=[2, 3, 4]) - obj3 = await testmodels.ArrayFields.create(array=[3, 4, 5]) + found = await testmodels.ArrayFields.filter(array__overlap=[1, 2]) + assert found == [obj1, obj2] - found 
= await testmodels.ArrayFields.filter(array__overlap=[1, 2]) - self.assertEqual(found, [obj1, obj2]) + found = await testmodels.ArrayFields.filter(array__overlap=[4]) + assert found == [obj2, obj3] - found = await testmodels.ArrayFields.filter(array__overlap=[4]) - self.assertEqual(found, [obj2, obj3]) + found = await testmodels.ArrayFields.filter(array__overlap=[1, 2, 3, 4, 5]) + assert found == [obj1, obj2, obj3] - found = await testmodels.ArrayFields.filter(array__overlap=[1, 2, 3, 4, 5]) - self.assertEqual(found, [obj1, obj2, obj3]) - async def test_array_length(self): - await testmodels.ArrayFields.create(array=[1, 2, 3]) - await testmodels.ArrayFields.create(array=[1]) - await testmodels.ArrayFields.create(array=[1, 2]) +@requireCapability(dialect="postgres") +@pytest.mark.asyncio +async def test_array_length(db_array_fields): + """Test array length filter.""" + await testmodels.ArrayFields.create(array=[1, 2, 3]) + await testmodels.ArrayFields.create(array=[1]) + await testmodels.ArrayFields.create(array=[1, 2]) - found = await testmodels.ArrayFields.filter(array__len=3).values_list("array", flat=True) - self.assertEqual(list(found), [[1, 2, 3]]) + found = await testmodels.ArrayFields.filter(array__len=3).values_list("array", flat=True) + assert list(found) == [[1, 2, 3]] - found = await testmodels.ArrayFields.filter(array__len=1).values_list("array", flat=True) - self.assertEqual(list(found), [[1]]) + found = await testmodels.ArrayFields.filter(array__len=1).values_list("array", flat=True) + assert list(found) == [[1]] - found = await testmodels.ArrayFields.filter(array__len=0).values_list("array", flat=True) - self.assertEqual(list(found), []) + found = await testmodels.ArrayFields.filter(array__len=0).values_list("array", flat=True) + assert list(found) == [] diff --git a/tests/fields/test_binary.py b/tests/fields/test_binary.py index dfbcc74f8..a7713f7ac 100644 --- a/tests/fields/test_binary.py +++ b/tests/fields/test_binary.py @@ -1,45 +1,54 @@ +import 
pytest + from tests import testmodels -from tortoise.contrib import test from tortoise.exceptions import ConfigurationError, IntegrityError from tortoise.fields import BinaryField -class TestBinaryFields(test.TestCase): - async def test_empty(self): - with self.assertRaises(IntegrityError): - await testmodels.BinaryFields.create() - - async def test_create(self): - obj0 = await testmodels.BinaryFields.create(binary=bytes(range(256)) * 500) - obj = await testmodels.BinaryFields.get(id=obj0.id) - self.assertEqual(obj.binary, bytes(range(256)) * 500) - self.assertEqual(obj.binary_null, None) - await obj.save() - obj2 = await testmodels.BinaryFields.get(id=obj.id) - self.assertEqual(obj, obj2) - - async def test_values(self): - obj0 = await testmodels.BinaryFields.create( - binary=bytes(range(256)), binary_null=bytes(range(255, -1, -1)) - ) - values = await testmodels.BinaryFields.get(id=obj0.id).values("binary", "binary_null") - self.assertEqual(values["binary"], bytes(range(256))) - self.assertEqual(values["binary_null"], bytes(range(255, -1, -1))) - - async def test_values_list(self): - obj0 = await testmodels.BinaryFields.create(binary=bytes(range(256))) - values = await testmodels.BinaryFields.get(id=obj0.id).values_list("binary", flat=True) - self.assertEqual(values, bytes(range(256))) - - def test_unique_fail(self): - with self.assertRaisesRegex(ConfigurationError, "can't be indexed"): - BinaryField(unique=True) - - def test_index_fail(self): - with self.assertRaisesRegex(ConfigurationError, "can't be indexed"): - with self.assertWarnsRegex( - DeprecationWarning, "`index` is deprecated, please use `db_index` instead" - ): - BinaryField(index=True) - with self.assertRaisesRegex(ConfigurationError, "can't be indexed"): - BinaryField(db_index=True) +@pytest.mark.asyncio +async def test_empty(db): + with pytest.raises(IntegrityError): + await testmodels.BinaryFields.create() + + +@pytest.mark.asyncio +async def test_create(db): + obj0 = await 
testmodels.BinaryFields.create(binary=bytes(range(256)) * 500) + obj = await testmodels.BinaryFields.get(id=obj0.id) + assert obj.binary == bytes(range(256)) * 500 + assert obj.binary_null is None + await obj.save() + obj2 = await testmodels.BinaryFields.get(id=obj.id) + assert obj == obj2 + + +@pytest.mark.asyncio +async def test_values(db): + obj0 = await testmodels.BinaryFields.create( + binary=bytes(range(256)), binary_null=bytes(range(255, -1, -1)) + ) + values = await testmodels.BinaryFields.get(id=obj0.id).values("binary", "binary_null") + assert values["binary"] == bytes(range(256)) + assert values["binary_null"] == bytes(range(255, -1, -1)) + + +@pytest.mark.asyncio +async def test_values_list(db): + obj0 = await testmodels.BinaryFields.create(binary=bytes(range(256))) + values = await testmodels.BinaryFields.get(id=obj0.id).values_list("binary", flat=True) + assert values == bytes(range(256)) + + +def test_unique_fail(): + with pytest.raises(ConfigurationError, match="can't be indexed"): + BinaryField(unique=True) + + +def test_index_fail(): + with pytest.warns( + DeprecationWarning, match="`index` is deprecated, please use `db_index` instead" + ): + with pytest.raises(ConfigurationError, match="can't be indexed"): + BinaryField(index=True) + with pytest.raises(ConfigurationError, match="can't be indexed"): + BinaryField(db_index=True) diff --git a/tests/fields/test_bool.py b/tests/fields/test_bool.py index b927c68f2..661486844 100644 --- a/tests/fields/test_bool.py +++ b/tests/fields/test_bool.py @@ -1,35 +1,44 @@ +import pytest + from tests import testmodels -from tortoise.contrib import test from tortoise.exceptions import IntegrityError -class TestBooleanFields(test.TestCase): - async def test_empty(self): - with self.assertRaises(IntegrityError): - await testmodels.BooleanFields.create() - - async def test_create(self): - obj0 = await testmodels.BooleanFields.create(boolean=True) - obj = await testmodels.BooleanFields.get(id=obj0.id) - 
self.assertIs(obj.boolean, True) - self.assertIs(obj.boolean_null, None) - await obj.save() - obj2 = await testmodels.BooleanFields.get(id=obj.id) - self.assertEqual(obj, obj2) - - async def test_update(self): - obj0 = await testmodels.BooleanFields.create(boolean=False) - await testmodels.BooleanFields.filter(id=obj0.id).update(boolean=False) - obj = await testmodels.BooleanFields.get(id=obj0.id) - self.assertIs(obj.boolean, False) - self.assertIs(obj.boolean_null, None) - - async def test_values(self): - obj0 = await testmodels.BooleanFields.create(boolean=True) - values = await testmodels.BooleanFields.get(id=obj0.id).values("boolean") - self.assertIs(values["boolean"], True) - - async def test_values_list(self): - obj0 = await testmodels.BooleanFields.create(boolean=True) - values = await testmodels.BooleanFields.get(id=obj0.id).values_list("boolean", flat=True) - self.assertIs(values, True) +@pytest.mark.asyncio +async def test_empty(db): + with pytest.raises(IntegrityError): + await testmodels.BooleanFields.create() + + +@pytest.mark.asyncio +async def test_create(db): + obj0 = await testmodels.BooleanFields.create(boolean=True) + obj = await testmodels.BooleanFields.get(id=obj0.id) + assert obj.boolean is True + assert obj.boolean_null is None + await obj.save() + obj2 = await testmodels.BooleanFields.get(id=obj.id) + assert obj == obj2 + + +@pytest.mark.asyncio +async def test_update(db): + obj0 = await testmodels.BooleanFields.create(boolean=False) + await testmodels.BooleanFields.filter(id=obj0.id).update(boolean=False) + obj = await testmodels.BooleanFields.get(id=obj0.id) + assert obj.boolean is False + assert obj.boolean_null is None + + +@pytest.mark.asyncio +async def test_values(db): + obj0 = await testmodels.BooleanFields.create(boolean=True) + values = await testmodels.BooleanFields.get(id=obj0.id).values("boolean") + assert values["boolean"] is True + + +@pytest.mark.asyncio +async def test_values_list(db): + obj0 = await 
testmodels.BooleanFields.create(boolean=True) + values = await testmodels.BooleanFields.get(id=obj0.id).values_list("boolean", flat=True) + assert values is True diff --git a/tests/fields/test_char.py b/tests/fields/test_char.py index 55dc2e599..02ce6532f 100644 --- a/tests/fields/test_char.py +++ b/tests/fields/test_char.py @@ -1,51 +1,62 @@ +import pytest + from tests import testmodels from tortoise import fields -from tortoise.contrib import test from tortoise.exceptions import ConfigurationError, ValidationError -class TestCharFields(test.TestCase): - def test_max_length_missing(self): - with self.assertRaisesRegex( - TypeError, "missing 1 required positional argument: 'max_length'" - ): - fields.CharField() # pylint: disable=E1120 - - def test_max_length_bad(self): - with self.assertRaisesRegex(ConfigurationError, "'max_length' must be >= 1"): - fields.CharField(max_length=0) - - async def test_empty(self): - with self.assertRaises(ValidationError): - await testmodels.CharFields.create() - - async def test_create(self): - obj0 = await testmodels.CharFields.create(char="moo") - obj = await testmodels.CharFields.get(id=obj0.id) - self.assertEqual(obj.char, "moo") - self.assertEqual(obj.char_null, None) - await obj.save() - obj2 = await testmodels.CharFields.get(id=obj.id) - self.assertEqual(obj, obj2) - - async def test_update(self): - obj0 = await testmodels.CharFields.create(char="moo") - await testmodels.CharFields.filter(id=obj0.id).update(char="ba'a") - obj = await testmodels.CharFields.get(id=obj0.id) - self.assertEqual(obj.char, "ba'a") - self.assertEqual(obj.char_null, None) - - async def test_cast(self): - obj0 = await testmodels.CharFields.create(char=33) - obj = await testmodels.CharFields.get(id=obj0.id) - self.assertEqual(obj.char, "33") - - async def test_values(self): - obj0 = await testmodels.CharFields.create(char="moo") - values = await testmodels.CharFields.get(id=obj0.id).values("char") - self.assertEqual(values["char"], "moo") - - async def 
test_values_list(self): - obj0 = await testmodels.CharFields.create(char="moo") - values = await testmodels.CharFields.get(id=obj0.id).values_list("char", flat=True) - self.assertEqual(values, "moo") +def test_max_length_missing(): + with pytest.raises(TypeError, match="missing 1 required positional argument: 'max_length'"): + fields.CharField() # pylint: disable=E1120 + + +def test_max_length_bad(): + with pytest.raises(ConfigurationError, match="'max_length' must be >= 1"): + fields.CharField(max_length=0) + + +@pytest.mark.asyncio +async def test_empty(db): + with pytest.raises(ValidationError): + await testmodels.CharFields.create() + + +@pytest.mark.asyncio +async def test_create(db): + obj0 = await testmodels.CharFields.create(char="moo") + obj = await testmodels.CharFields.get(id=obj0.id) + assert obj.char == "moo" + assert obj.char_null is None + await obj.save() + obj2 = await testmodels.CharFields.get(id=obj.id) + assert obj == obj2 + + +@pytest.mark.asyncio +async def test_update(db): + obj0 = await testmodels.CharFields.create(char="moo") + await testmodels.CharFields.filter(id=obj0.id).update(char="ba'a") + obj = await testmodels.CharFields.get(id=obj0.id) + assert obj.char == "ba'a" + assert obj.char_null is None + + +@pytest.mark.asyncio +async def test_cast(db): + obj0 = await testmodels.CharFields.create(char=33) + obj = await testmodels.CharFields.get(id=obj0.id) + assert obj.char == "33" + + +@pytest.mark.asyncio +async def test_values(db): + obj0 = await testmodels.CharFields.create(char="moo") + values = await testmodels.CharFields.get(id=obj0.id).values("char") + assert values["char"] == "moo" + + +@pytest.mark.asyncio +async def test_values_list(db): + obj0 = await testmodels.CharFields.create(char="moo") + values = await testmodels.CharFields.get(id=obj0.id).values_list("char", flat=True) + assert values == "moo" diff --git a/tests/fields/test_common.py b/tests/fields/test_common.py index dd450479f..932f5c65f 100644 --- 
a/tests/fields/test_common.py +++ b/tests/fields/test_common.py @@ -1,19 +1,29 @@ +import pytest + from tortoise import fields -from tortoise.contrib import test -class TestRequired(test.SimpleTestCase): - async def test_required_by_default(self): - self.assertTrue(fields.Field().required) +# Tests for field.required property - no database access needed +@pytest.mark.asyncio +async def test_required_by_default(): + assert fields.Field().required is True + + +@pytest.mark.asyncio +async def test_if_generated_then_not_required(): + assert fields.Field(generated=True).required is False + + +@pytest.mark.asyncio +async def test_if_null_then_not_required(): + assert fields.Field(null=True).required is False - async def test_if_generated_then_not_required(self): - self.assertFalse(fields.Field(generated=True).required) - async def test_if_null_then_not_required(self): - self.assertFalse(fields.Field(null=True).required) +@pytest.mark.asyncio +async def test_if_has_non_null_default_then_not_required(): + assert fields.TextField(default="").required is False - async def test_if_has_non_null_default_then_not_required(self): - self.assertFalse(fields.TextField(default="").required) - async def test_if_null_default_then_required(self): - self.assertTrue(fields.TextField(default=None).required) +@pytest.mark.asyncio +async def test_if_null_default_then_required(): + assert fields.TextField(default=None).required is True diff --git a/tests/fields/test_db_index.py b/tests/fields/test_db_index.py index 82d69510e..5d653dcde 100644 --- a/tests/fields/test_db_index.py +++ b/tests/fields/test_db_index.py @@ -2,11 +2,11 @@ from typing import Any +import pytest from pypika_tortoise.terms import Field from tests.testmodels import ModelWithIndexes from tortoise import fields -from tortoise.contrib import test from tortoise.exceptions import ConfigurationError from tortoise.indexes import Index @@ -17,94 +17,119 @@ def __init__(self, *args, **kw): self._foo = "" -class 
TestIndexHashEqualRepr(test.SimpleTestCase): - def test_index_eq(self): - assert Index(fields=("id",)) == Index(fields=("id",)) - assert CustomIndex(fields=("id",)) == CustomIndex(fields=("id",)) - assert Index(fields=("id", "name")) == Index(fields=["id", "name"]) - - assert Index(fields=("id", "name")) != Index(fields=("name", "id")) - assert Index(fields=("id",)) != Index(fields=("name",)) - assert CustomIndex(fields=("id",)) != Index(fields=("id",)) - - def test_index_hash(self): - assert hash(Index(fields=("id",))) == hash(Index(fields=("id",))) - assert hash(Index(fields=("id", "name"))) == hash(Index(fields=["id", "name"])) - assert hash(CustomIndex(fields=("id", "name"))) == hash(CustomIndex(fields=["id", "name"])) - - assert hash(Index(fields=("id", "name"))) != hash(Index(fields=["name", "id"])) - assert hash(Index(fields=("id",))) != hash(Index(fields=("name",))) - - indexes = {Index(fields=("id",))} - indexes.add(Index(fields=("id",))) - assert len(indexes) == 1 - indexes.add(CustomIndex(fields=("id",))) - assert len(indexes) == 2 - indexes.add(Index(fields=("name",))) - assert len(indexes) == 3 - - def test_index_repr(self): - assert repr(Index(fields=("id",))) == "Index(fields=['id'])" - assert repr(Index(fields=("id", "name"))) == "Index(fields=['id', 'name'])" - assert repr(Index(fields=("id",), name="MyIndex")) == "Index(fields=['id'], name='MyIndex')" - assert repr(Index(Field("id"))) == f"Index({str(Field('id'))})" - assert repr(Index(Field("a"), name="Id")) == f"Index({str(Field('a'))}, name='Id')" - with self.assertRaises(ConfigurationError): - Index(Field("id"), fields=("name",)) - - -class TestIndexAlias(test.TestCase): - Field: Any = fields.IntField - - def test_index_alias(self) -> None: - kwargs: dict = getattr(self, "init_kwargs", {}) - with self.assertWarnsRegex( - DeprecationWarning, "`index` is deprecated, please use `db_index` instead" - ): - f = self.Field(index=True, **kwargs) - assert f.index is True - with self.assertWarnsRegex( - 
DeprecationWarning, "`index` is deprecated, please use `db_index` instead" - ): - f = self.Field(index=False, **kwargs) - assert f.index is False - f = self.Field(db_index=True, **kwargs) - assert f.index is True - f = self.Field(db_index=True, index=True, **kwargs) - assert f.index is True - f = self.Field(db_index=False, **kwargs) - assert f.index is False - f = self.Field(db_index=False, index=False, **kwargs) - assert f.index is False - with self.assertRaisesRegex(ConfigurationError, "can't set both db_index and index"): - self.Field(db_index=False, index=True, **kwargs) - with self.assertRaisesRegex(ConfigurationError, "can't set both db_index and index"): - self.Field(db_index=True, index=False, **kwargs) - - -class TestIndexAliasSmallInt(TestIndexAlias): - Field = fields.SmallIntField - - -class TestIndexAliasBigInt(TestIndexAlias): - Field = fields.BigIntField - - -class TestIndexAliasUUID(TestIndexAlias): - Field = fields.UUIDField - - -class TestIndexAliasChar(TestIndexAlias): - Field = fields.CharField - init_kwargs = {"max_length": 10} - - -class TestModelWithIndexes(test.TestCase): - def test_meta(self): - self.assertEqual( - ModelWithIndexes._meta.indexes, - [Index(fields=("f1", "f2")), Index(fields=("f3",), name="model_with_indexes__f3")], - ) - self.assertTrue(ModelWithIndexes._meta.fields_map["id"].index) - self.assertTrue(ModelWithIndexes._meta.fields_map["indexed"].index) - self.assertTrue(ModelWithIndexes._meta.fields_map["unique_indexed"].unique) +# ============================================================================ +# Tests for Index hash, equality, and repr (no database needed) +# ============================================================================ + + +def test_index_eq(): + assert Index(fields=("id",)) == Index(fields=("id",)) + assert CustomIndex(fields=("id",)) == CustomIndex(fields=("id",)) + assert Index(fields=("id", "name")) == Index(fields=["id", "name"]) + + assert Index(fields=("id", "name")) != 
Index(fields=("name", "id")) + assert Index(fields=("id",)) != Index(fields=("name",)) + assert CustomIndex(fields=("id",)) != Index(fields=("id",)) + + +def test_index_hash(): + assert hash(Index(fields=("id",))) == hash(Index(fields=("id",))) + assert hash(Index(fields=("id", "name"))) == hash(Index(fields=["id", "name"])) + assert hash(CustomIndex(fields=("id", "name"))) == hash(CustomIndex(fields=["id", "name"])) + + assert hash(Index(fields=("id", "name"))) != hash(Index(fields=["name", "id"])) + assert hash(Index(fields=("id",))) != hash(Index(fields=("name",))) + + indexes = {Index(fields=("id",))} + indexes.add(Index(fields=("id",))) + assert len(indexes) == 1 + indexes.add(CustomIndex(fields=("id",))) + assert len(indexes) == 2 + indexes.add(Index(fields=("name",))) + assert len(indexes) == 3 + + +def test_index_repr(): + assert repr(Index(fields=("id",))) == "Index(fields=['id'])" + assert repr(Index(fields=("id", "name"))) == "Index(fields=['id', 'name'])" + assert repr(Index(fields=("id",), name="MyIndex")) == "Index(fields=['id'], name='MyIndex')" + assert repr(Index(Field("id"))) == f"Index({str(Field('id'))})" + assert repr(Index(Field("a"), name="Id")) == f"Index({str(Field('a'))}, name='Id')" + with pytest.raises(ConfigurationError): + Index(Field("id"), fields=("name",)) + + +# ============================================================================ +# Tests for index/db_index field alias (no database needed) +# ============================================================================ + + +def _test_index_alias_for_field(field_class: Any, init_kwargs: dict | None = None): + """Helper function to test index alias behavior for a given field class.""" + kwargs: dict = init_kwargs or {} + + with pytest.warns( + DeprecationWarning, match="`index` is deprecated, please use `db_index` instead" + ): + f = field_class(index=True, **kwargs) + assert f.index is True + + with pytest.warns( + DeprecationWarning, match="`index` is deprecated, please use 
`db_index` instead" + ): + f = field_class(index=False, **kwargs) + assert f.index is False + + f = field_class(db_index=True, **kwargs) + assert f.index is True + + f = field_class(db_index=True, index=True, **kwargs) + assert f.index is True + + f = field_class(db_index=False, **kwargs) + assert f.index is False + + f = field_class(db_index=False, index=False, **kwargs) + assert f.index is False + + with pytest.raises(ConfigurationError, match="can't set both db_index and index"): + field_class(db_index=False, index=True, **kwargs) + + with pytest.raises(ConfigurationError, match="can't set both db_index and index"): + field_class(db_index=True, index=False, **kwargs) + + +def test_index_alias_int_field(): + _test_index_alias_for_field(fields.IntField) + + +def test_index_alias_small_int_field(): + _test_index_alias_for_field(fields.SmallIntField) + + +def test_index_alias_big_int_field(): + _test_index_alias_for_field(fields.BigIntField) + + +def test_index_alias_uuid_field(): + _test_index_alias_for_field(fields.UUIDField) + + +def test_index_alias_char_field(): + _test_index_alias_for_field(fields.CharField, init_kwargs={"max_length": 10}) + + +# ============================================================================ +# Tests for ModelWithIndexes metadata (requires database fixture) +# ============================================================================ + + +@pytest.mark.asyncio +async def test_model_with_indexes_meta(db): + assert ModelWithIndexes._meta.indexes == [ + Index(fields=("f1", "f2")), + Index(fields=("f3",), name="model_with_indexes__f3"), + ] + assert ModelWithIndexes._meta.fields_map["id"].index + assert ModelWithIndexes._meta.fields_map["indexed"].index + assert ModelWithIndexes._meta.fields_map["unique_indexed"].unique diff --git a/tests/fields/test_decimal.py b/tests/fields/test_decimal.py index 8fc1c614f..167b61874 100644 --- a/tests/fields/test_decimal.py +++ b/tests/fields/test_decimal.py @@ -1,336 +1,311 @@ from decimal import 
Decimal +import pytest + from tests import testmodels from tortoise import fields -from tortoise.contrib import test from tortoise.exceptions import ConfigurationError, FieldError, IntegrityError from tortoise.expressions import F from tortoise.functions import Avg, Max, Sum -class TestDecimalFields(test.TestCase): - def test_max_digits_empty(self): - with self.assertRaisesRegex( - TypeError, - "missing 2 required positional arguments: 'max_digits' and 'decimal_places'", - ): - fields.DecimalField() # pylint: disable=E1120 - - def test_decimal_places_empty(self): - with self.assertRaisesRegex( - TypeError, "missing 1 required positional argument: 'decimal_places'" - ): - fields.DecimalField(max_digits=1) # pylint: disable=E1120 - - def test_max_fields_bad(self): - with self.assertRaisesRegex(ConfigurationError, "'max_digits' must be >= 1"): - fields.DecimalField(max_digits=0, decimal_places=2) - - def test_decimal_places_bad(self): - with self.assertRaisesRegex(ConfigurationError, "'decimal_places' must be >= 0"): - fields.DecimalField(max_digits=2, decimal_places=-1) - - async def test_empty(self): - with self.assertRaises(IntegrityError): - await testmodels.DecimalFields.create() - - async def test_create(self): - obj0 = await testmodels.DecimalFields.create(decimal=Decimal("1.23456"), decimal_nodec=18.7) - obj = await testmodels.DecimalFields.get(id=obj0.id) - self.assertEqual(obj.decimal, Decimal("1.2346")) - self.assertEqual(obj.decimal_nodec, 19) - self.assertEqual(obj.decimal_null, None) - await obj.save() - obj2 = await testmodels.DecimalFields.get(id=obj.id) - self.assertEqual(obj, obj2) - - async def test_update(self): - obj0 = await testmodels.DecimalFields.create(decimal=Decimal("1.23456"), decimal_nodec=18.7) - await testmodels.DecimalFields.filter(id=obj0.id).update(decimal=Decimal("2.345")) - obj = await testmodels.DecimalFields.get(id=obj0.id) - self.assertEqual(obj.decimal, Decimal("2.345")) - self.assertEqual(obj.decimal_nodec, 19) - 
self.assertEqual(obj.decimal_null, None) - - async def test_filter(self): - obj0 = await testmodels.DecimalFields.create(decimal=Decimal("1.23456"), decimal_nodec=18.7) - obj = await testmodels.DecimalFields.filter(decimal=Decimal("1.2346")).first() - self.assertEqual(obj, obj0) - obj = ( - await testmodels.DecimalFields.annotate(d=F("decimal")) - .filter(d=Decimal("1.2346")) - .first() - ) - self.assertEqual(obj, obj0) - objs = await testmodels.DecimalFields.filter(decimal_nodec__gt=2).all() - self.assertIn(obj, objs) - objs = await testmodels.DecimalFields.filter(decimal_nodec__lt=100).all() - self.assertIn(obj, objs) - - async def test_f_expression_update(self): - obj0 = await testmodels.DecimalFields.create(decimal=Decimal("1.23456"), decimal_nodec=18.7) - await obj0.filter(id=obj0.id).update(decimal=F("decimal") + Decimal("1")) - obj1 = await testmodels.DecimalFields.get(id=obj0.id) - self.assertEqual(obj1.decimal, Decimal("2.2346")) - await obj0.filter(id=obj0.id).update(decimal=Decimal("1") - F("decimal")) - obj1 = await testmodels.DecimalFields.get(id=obj0.id) - self.assertEqual(obj1.decimal, Decimal("-1.2346")) - - async def test_values(self): - obj0 = await testmodels.DecimalFields.create(decimal=Decimal("1.23456"), decimal_nodec=18.7) - values = await testmodels.DecimalFields.get(id=obj0.id).values("decimal", "decimal_nodec") - self.assertEqual(values["decimal"], Decimal("1.2346")) - self.assertEqual(values["decimal_nodec"], 19) - - async def test_values_list(self): - obj0 = await testmodels.DecimalFields.create(decimal=Decimal("1.23456"), decimal_nodec=18.7) - values = await testmodels.DecimalFields.get(id=obj0.id).values_list( - "decimal", "decimal_nodec" - ) - self.assertEqual(list(values), [Decimal("1.2346"), 19]) - - async def test_order_by(self): - await testmodels.DecimalFields.create(decimal=Decimal("0"), decimal_nodec=1) - await testmodels.DecimalFields.create(decimal=Decimal("9.99"), decimal_nodec=1) - await 
testmodels.DecimalFields.create(decimal=Decimal("27.27"), decimal_nodec=1) - values = ( - await testmodels.DecimalFields.all() - .order_by("decimal") - .values_list("decimal", flat=True) - ) - self.assertEqual(values, [Decimal("0"), Decimal("9.99"), Decimal("27.27")]) - - async def test_aggregate_sum(self): - await testmodels.DecimalFields.create(decimal=Decimal("0"), decimal_nodec=1) - await testmodels.DecimalFields.create(decimal=Decimal("9.99"), decimal_nodec=1) - await testmodels.DecimalFields.create(decimal=Decimal("27.27"), decimal_nodec=1) - values = ( - await testmodels.DecimalFields.all() - .annotate(sum_decimal=Sum("decimal")) +def test_max_digits_empty(): + with pytest.raises( + TypeError, + match="missing 2 required positional arguments: 'max_digits' and 'decimal_places'", + ): + fields.DecimalField() # pylint: disable=E1120 + + +def test_decimal_places_empty(): + with pytest.raises(TypeError, match="missing 1 required positional argument: 'decimal_places'"): + fields.DecimalField(max_digits=1) # pylint: disable=E1120 + + +def test_max_fields_bad(): + with pytest.raises(ConfigurationError, match="'max_digits' must be >= 1"): + fields.DecimalField(max_digits=0, decimal_places=2) + + +def test_decimal_places_bad(): + with pytest.raises(ConfigurationError, match="'decimal_places' must be >= 0"): + fields.DecimalField(max_digits=2, decimal_places=-1) + + +@pytest.mark.asyncio +async def test_empty(db): + with pytest.raises(IntegrityError): + await testmodels.DecimalFields.create() + + +@pytest.mark.asyncio +async def test_create(db): + obj0 = await testmodels.DecimalFields.create(decimal=Decimal("1.23456"), decimal_nodec=18.7) + obj = await testmodels.DecimalFields.get(id=obj0.id) + assert obj.decimal == Decimal("1.2346") + assert obj.decimal_nodec == 19 + assert obj.decimal_null is None + await obj.save() + obj2 = await testmodels.DecimalFields.get(id=obj.id) + assert obj == obj2 + + +@pytest.mark.asyncio +async def test_update(db): + obj0 = await 
testmodels.DecimalFields.create(decimal=Decimal("1.23456"), decimal_nodec=18.7) + await testmodels.DecimalFields.filter(id=obj0.id).update(decimal=Decimal("2.345")) + obj = await testmodels.DecimalFields.get(id=obj0.id) + assert obj.decimal == Decimal("2.345") + assert obj.decimal_nodec == 19 + assert obj.decimal_null is None + + +@pytest.mark.asyncio +async def test_filter(db): + obj0 = await testmodels.DecimalFields.create(decimal=Decimal("1.23456"), decimal_nodec=18.7) + obj = await testmodels.DecimalFields.filter(decimal=Decimal("1.2346")).first() + assert obj == obj0 + obj = ( + await testmodels.DecimalFields.annotate(d=F("decimal")).filter(d=Decimal("1.2346")).first() + ) + assert obj == obj0 + objs = await testmodels.DecimalFields.filter(decimal_nodec__gt=2).all() + assert obj in objs + objs = await testmodels.DecimalFields.filter(decimal_nodec__lt=100).all() + assert obj in objs + + +@pytest.mark.asyncio +async def test_f_expression_update(db): + obj0 = await testmodels.DecimalFields.create(decimal=Decimal("1.23456"), decimal_nodec=18.7) + await obj0.filter(id=obj0.id).update(decimal=F("decimal") + Decimal("1")) + obj1 = await testmodels.DecimalFields.get(id=obj0.id) + assert obj1.decimal == Decimal("2.2346") + await obj0.filter(id=obj0.id).update(decimal=Decimal("1") - F("decimal")) + obj1 = await testmodels.DecimalFields.get(id=obj0.id) + assert obj1.decimal == Decimal("-1.2346") + + +@pytest.mark.asyncio +async def test_values(db): + obj0 = await testmodels.DecimalFields.create(decimal=Decimal("1.23456"), decimal_nodec=18.7) + values = await testmodels.DecimalFields.get(id=obj0.id).values("decimal", "decimal_nodec") + assert values["decimal"] == Decimal("1.2346") + assert values["decimal_nodec"] == 19 + + +@pytest.mark.asyncio +async def test_values_list(db): + obj0 = await testmodels.DecimalFields.create(decimal=Decimal("1.23456"), decimal_nodec=18.7) + values = await testmodels.DecimalFields.get(id=obj0.id).values_list("decimal", "decimal_nodec") + 
assert list(values) == [Decimal("1.2346"), 19] + + +@pytest.mark.asyncio +async def test_order_by(db): + await testmodels.DecimalFields.create(decimal=Decimal("0"), decimal_nodec=1) + await testmodels.DecimalFields.create(decimal=Decimal("9.99"), decimal_nodec=1) + await testmodels.DecimalFields.create(decimal=Decimal("27.27"), decimal_nodec=1) + values = ( + await testmodels.DecimalFields.all().order_by("decimal").values_list("decimal", flat=True) + ) + assert values == [Decimal("0"), Decimal("9.99"), Decimal("27.27")] + + +@pytest.mark.asyncio +async def test_aggregate_sum(db): + await testmodels.DecimalFields.create(decimal=Decimal("0"), decimal_nodec=1) + await testmodels.DecimalFields.create(decimal=Decimal("9.99"), decimal_nodec=1) + await testmodels.DecimalFields.create(decimal=Decimal("27.27"), decimal_nodec=1) + values = ( + await testmodels.DecimalFields.all() + .annotate(sum_decimal=Sum("decimal")) + .values("sum_decimal") + ) + assert values[0] == {"sum_decimal": Decimal("37.26")} + + +@pytest.mark.asyncio +async def test_aggregate_sum_with_f_expression(db): + await testmodels.DecimalFields.create(decimal=Decimal("0"), decimal_nodec=1) + await testmodels.DecimalFields.create(decimal=Decimal("9.99"), decimal_nodec=1) + await testmodels.DecimalFields.create(decimal=Decimal("27.27"), decimal_nodec=1) + values = ( + await testmodels.DecimalFields.all() + .annotate(sum_decimal=Sum(F("decimal"))) + .values("sum_decimal") + ) + assert values[0] == {"sum_decimal": Decimal("37.26")} + + values = ( + await testmodels.DecimalFields.all() + .annotate(sum_decimal=Sum(F("decimal") + 1)) + .values("sum_decimal") + ) + assert values[0] == {"sum_decimal": Decimal("40.26")} + + values = ( + await testmodels.DecimalFields.all() + .annotate(sum_decimal=Sum(F("decimal") + F("decimal"))) + .values("sum_decimal") + ) + assert values[0] == {"sum_decimal": Decimal("74.52")} + + values = ( + await testmodels.DecimalFields.all() + .annotate(sum_decimal=Sum(F("decimal") + 
F("decimal_nodec"))) + .values("sum_decimal") + ) + assert values[0] == {"sum_decimal": Decimal("4E+1")} + + values = ( + await testmodels.DecimalFields.all() + .annotate(sum_decimal=Sum(F("decimal") + F("decimal_null"))) + .values("sum_decimal") + ) + assert values[0] == {"sum_decimal": None} + + +@pytest.mark.asyncio +async def test_aggregate_sum_no_exist_field_with_f_expression(db): + with pytest.raises( + FieldError, + match="There is no non-virtual field not_exist on Model DecimalFields", + ): + await ( + testmodels.DecimalFields.all() + .annotate(sum_decimal=Sum(F("not_exist"))) .values("sum_decimal") ) - self.assertEqual( - values[0], - {"sum_decimal": Decimal("37.26")}, - ) - async def test_aggregate_sum_with_f_expression(self): - await testmodels.DecimalFields.create(decimal=Decimal("0"), decimal_nodec=1) - await testmodels.DecimalFields.create(decimal=Decimal("9.99"), decimal_nodec=1) - await testmodels.DecimalFields.create(decimal=Decimal("27.27"), decimal_nodec=1) - values = ( - await testmodels.DecimalFields.all() - .annotate(sum_decimal=Sum(F("decimal"))) - .values("sum_decimal") - ) - self.assertEqual( - values[0], - {"sum_decimal": Decimal("37.26")}, - ) - - values = ( - await testmodels.DecimalFields.all() - .annotate(sum_decimal=Sum(F("decimal") + 1)) - .values("sum_decimal") - ) - self.assertEqual( - values[0], - {"sum_decimal": Decimal("40.26")}, - ) - values = ( - await testmodels.DecimalFields.all() - .annotate(sum_decimal=Sum(F("decimal") + F("decimal"))) +@pytest.mark.asyncio +async def test_aggregate_sum_different_field_type_at_right_with_f_expression(db): + with pytest.raises( + FieldError, match="Cannot use arithmetic expression between different field type" + ): + await ( + testmodels.DecimalFields.all() + .annotate(sum_decimal=Sum(F("decimal") + F("id"))) .values("sum_decimal") ) - self.assertEqual( - values[0], - {"sum_decimal": Decimal("74.52")}, - ) - values = ( - await testmodels.DecimalFields.all() - 
.annotate(sum_decimal=Sum(F("decimal") + F("decimal_nodec"))) - .values("sum_decimal") - ) - self.assertEqual( - values[0], - {"sum_decimal": Decimal("4E+1")}, - ) - values = ( - await testmodels.DecimalFields.all() - .annotate(sum_decimal=Sum(F("decimal") + F("decimal_null"))) +@pytest.mark.asyncio +async def test_aggregate_sum_different_field_type_at_left_with_f_expression(db): + with pytest.raises( + FieldError, match="Cannot use arithmetic expression between different field type" + ): + await ( + testmodels.DecimalFields.all() + .annotate(sum_decimal=Sum(F("id") + F("decimal"))) .values("sum_decimal") ) - self.assertEqual( - values[0], - {"sum_decimal": None}, - ) - async def test_aggregate_sum_no_exist_field_with_f_expression(self): - with self.assertRaisesRegex( - FieldError, - "There is no non-virtual field not_exist on Model DecimalFields", - ): - await ( - testmodels.DecimalFields.all() - .annotate(sum_decimal=Sum(F("not_exist"))) - .values("sum_decimal") - ) - - async def test_aggregate_sum_different_field_type_at_right_with_f_expression(self): - with self.assertRaisesRegex( - FieldError, "Cannot use arithmetic expression between different field type" - ): - await ( - testmodels.DecimalFields.all() - .annotate(sum_decimal=Sum(F("decimal") + F("id"))) - .values("sum_decimal") - ) - - async def test_aggregate_sum_different_field_type_at_left_with_f_expression(self): - with self.assertRaisesRegex( - FieldError, "Cannot use arithmetic expression between different field type" - ): - await ( - testmodels.DecimalFields.all() - .annotate(sum_decimal=Sum(F("id") + F("decimal"))) - .values("sum_decimal") - ) - - async def test_aggregate_avg(self): - await testmodels.DecimalFields.create(decimal=Decimal("0"), decimal_nodec=1) - await testmodels.DecimalFields.create(decimal=Decimal("9.99"), decimal_nodec=1) - await testmodels.DecimalFields.create(decimal=Decimal("27.27"), decimal_nodec=1) - values = ( - await testmodels.DecimalFields.all() - 
.annotate(avg_decimal=Avg("decimal")) - .values("avg_decimal") - ) - self.assertEqual( - values[0], - {"avg_decimal": Decimal("12.42")}, - ) - async def test_aggregate_avg_with_f_expression(self): - await testmodels.DecimalFields.create(decimal=Decimal("0"), decimal_nodec=1) - await testmodels.DecimalFields.create(decimal=Decimal("9.99"), decimal_nodec=1) - await testmodels.DecimalFields.create(decimal=Decimal("27.27"), decimal_nodec=1) - values = ( - await testmodels.DecimalFields.all() - .annotate(avg_decimal=Avg(F("decimal"))) - .values("avg_decimal") - ) - self.assertEqual( - values[0], - {"avg_decimal": Decimal("12.42")}, - ) - - values = ( - await testmodels.DecimalFields.all() - .annotate(avg_decimal=Avg(F("decimal") + 1)) - .values("avg_decimal") - ) - self.assertEqual( - values[0], - {"avg_decimal": Decimal("13.42")}, - ) - - values = ( - await testmodels.DecimalFields.all() - .annotate(avg_decimal=Avg(F("decimal") + F("decimal"))) - .values("avg_decimal") - ) - self.assertEqual( - values[0], - {"avg_decimal": Decimal("24.84")}, - ) - - values = ( - await testmodels.DecimalFields.all() - .annotate(avg_decimal=Avg(F("decimal") + F("decimal_nodec"))) - .values("avg_decimal") - ) - self.assertEqual( - values[0], - {"avg_decimal": Decimal("13")}, - ) - - values = ( - await testmodels.DecimalFields.all() - .annotate(avg_decimal=Avg(F("decimal") + F("decimal_null"))) - .values("avg_decimal") - ) - self.assertEqual( - values[0], - {"avg_decimal": None}, - ) - - async def test_aggregate_max(self): - await testmodels.DecimalFields.create(decimal=Decimal("0"), decimal_nodec=1) - await testmodels.DecimalFields.create(decimal=Decimal("9.99"), decimal_nodec=1) - await testmodels.DecimalFields.create(decimal=Decimal("27.27"), decimal_nodec=1) - values = ( - await testmodels.DecimalFields.all() - .annotate(max_decimal=Max("decimal")) - .values("max_decimal") - ) - self.assertEqual( - values[0], - {"max_decimal": Decimal("27.27")}, - ) - - async def 
test_aggregate_max_with_f_expression(self): - await testmodels.DecimalFields.create(decimal=Decimal("0"), decimal_nodec=1) - await testmodels.DecimalFields.create(decimal=Decimal("9.99"), decimal_nodec=1) - await testmodels.DecimalFields.create(decimal=Decimal("27.27"), decimal_nodec=1) - values = ( - await testmodels.DecimalFields.all() - .annotate(max_decimal=Max(F("decimal"))) - .values("max_decimal") - ) - self.assertEqual( - values[0], - {"max_decimal": Decimal("27.27")}, - ) - - values = ( - await testmodels.DecimalFields.all() - .annotate(max_decimal=Max(F("decimal") + 1)) - .values("max_decimal") - ) - self.assertEqual( - values[0], - {"max_decimal": Decimal("28.27")}, - ) - - values = ( - await testmodels.DecimalFields.all() - .annotate(max_decimal=Max(F("decimal") + F("decimal"))) - .values("max_decimal") - ) - self.assertEqual( - values[0], - {"max_decimal": Decimal("54.54")}, - ) - - values = ( - await testmodels.DecimalFields.all() - .annotate(max_decimal=Max(F("decimal") + F("decimal_nodec"))) - .values("max_decimal") - ) - self.assertEqual( - values[0], - {"max_decimal": Decimal("28")}, - ) - - values = ( - await testmodels.DecimalFields.all() - .annotate(max_decimal=Max(F("decimal") + F("decimal_null"))) - .values("max_decimal") - ) - self.assertEqual( - values[0], - {"max_decimal": None}, - ) +@pytest.mark.asyncio +async def test_aggregate_avg(db): + await testmodels.DecimalFields.create(decimal=Decimal("0"), decimal_nodec=1) + await testmodels.DecimalFields.create(decimal=Decimal("9.99"), decimal_nodec=1) + await testmodels.DecimalFields.create(decimal=Decimal("27.27"), decimal_nodec=1) + values = ( + await testmodels.DecimalFields.all() + .annotate(avg_decimal=Avg("decimal")) + .values("avg_decimal") + ) + assert values[0] == {"avg_decimal": Decimal("12.42")} + + +@pytest.mark.asyncio +async def test_aggregate_avg_with_f_expression(db): + await testmodels.DecimalFields.create(decimal=Decimal("0"), decimal_nodec=1) + await 
testmodels.DecimalFields.create(decimal=Decimal("9.99"), decimal_nodec=1) + await testmodels.DecimalFields.create(decimal=Decimal("27.27"), decimal_nodec=1) + values = ( + await testmodels.DecimalFields.all() + .annotate(avg_decimal=Avg(F("decimal"))) + .values("avg_decimal") + ) + assert values[0] == {"avg_decimal": Decimal("12.42")} + + values = ( + await testmodels.DecimalFields.all() + .annotate(avg_decimal=Avg(F("decimal") + 1)) + .values("avg_decimal") + ) + assert values[0] == {"avg_decimal": Decimal("13.42")} + + values = ( + await testmodels.DecimalFields.all() + .annotate(avg_decimal=Avg(F("decimal") + F("decimal"))) + .values("avg_decimal") + ) + assert values[0] == {"avg_decimal": Decimal("24.84")} + + values = ( + await testmodels.DecimalFields.all() + .annotate(avg_decimal=Avg(F("decimal") + F("decimal_nodec"))) + .values("avg_decimal") + ) + assert values[0] == {"avg_decimal": Decimal("13")} + + values = ( + await testmodels.DecimalFields.all() + .annotate(avg_decimal=Avg(F("decimal") + F("decimal_null"))) + .values("avg_decimal") + ) + assert values[0] == {"avg_decimal": None} + + +@pytest.mark.asyncio +async def test_aggregate_max(db): + await testmodels.DecimalFields.create(decimal=Decimal("0"), decimal_nodec=1) + await testmodels.DecimalFields.create(decimal=Decimal("9.99"), decimal_nodec=1) + await testmodels.DecimalFields.create(decimal=Decimal("27.27"), decimal_nodec=1) + values = ( + await testmodels.DecimalFields.all() + .annotate(max_decimal=Max("decimal")) + .values("max_decimal") + ) + assert values[0] == {"max_decimal": Decimal("27.27")} + + +@pytest.mark.asyncio +async def test_aggregate_max_with_f_expression(db): + await testmodels.DecimalFields.create(decimal=Decimal("0"), decimal_nodec=1) + await testmodels.DecimalFields.create(decimal=Decimal("9.99"), decimal_nodec=1) + await testmodels.DecimalFields.create(decimal=Decimal("27.27"), decimal_nodec=1) + values = ( + await testmodels.DecimalFields.all() + 
.annotate(max_decimal=Max(F("decimal"))) + .values("max_decimal") + ) + assert values[0] == {"max_decimal": Decimal("27.27")} + + values = ( + await testmodels.DecimalFields.all() + .annotate(max_decimal=Max(F("decimal") + 1)) + .values("max_decimal") + ) + assert values[0] == {"max_decimal": Decimal("28.27")} + + values = ( + await testmodels.DecimalFields.all() + .annotate(max_decimal=Max(F("decimal") + F("decimal"))) + .values("max_decimal") + ) + assert values[0] == {"max_decimal": Decimal("54.54")} + + values = ( + await testmodels.DecimalFields.all() + .annotate(max_decimal=Max(F("decimal") + F("decimal_nodec"))) + .values("max_decimal") + ) + assert values[0] == {"max_decimal": Decimal("28")} + + values = ( + await testmodels.DecimalFields.all() + .annotate(max_decimal=Max(F("decimal") + F("decimal_null"))) + .values("max_decimal") + ) + assert values[0] == {"max_decimal": None} diff --git a/tests/fields/test_enum.py b/tests/fields/test_enum.py index 6f02c5759..e89816703 100644 --- a/tests/fields/test_enum.py +++ b/tests/fields/test_enum.py @@ -1,7 +1,8 @@ from enum import IntEnum +import pytest + from tests import testmodels -from tortoise.contrib import test from tortoise.exceptions import ConfigurationError, IntegrityError from tortoise.fields import CharEnumField, IntEnumField @@ -24,173 +25,204 @@ class BadIntEnumIfGenerated(IntEnum): system_administration = 3 -class TestIntEnumFields(test.TestCase): - async def test_empty(self): - with self.assertRaises(IntegrityError): - await testmodels.EnumFields.create() - - async def test_create(self): - obj0 = await testmodels.EnumFields.create(service=testmodels.Service.system_administration) - self.assertIsInstance(obj0.service, testmodels.Service) - obj = await testmodels.EnumFields.get(id=obj0.id) - self.assertIsInstance(obj.service, testmodels.Service) - self.assertEqual(obj.service, testmodels.Service.system_administration) - await obj.save() - obj2 = await testmodels.EnumFields.get(id=obj.id) - 
self.assertEqual(obj, obj2) - - await obj.delete() - obj = await testmodels.EnumFields.filter(id=obj0.id).first() - self.assertEqual(obj, None) - - obj3 = await testmodels.EnumFields.create(service=3) - self.assertIsInstance(obj3.service, testmodels.Service) - with self.assertRaises(ValueError): - await testmodels.EnumFields.create(service=4) - - async def test_update(self): - obj0 = await testmodels.EnumFields.create(service=testmodels.Service.system_administration) - await testmodels.EnumFields.filter(id=obj0.id).update( - service=testmodels.Service.database_design - ) - obj = await testmodels.EnumFields.get(id=obj0.id) - self.assertEqual(obj.service, testmodels.Service.database_design) - - await testmodels.EnumFields.filter(id=obj0.id).update(service=2) - obj = await testmodels.EnumFields.get(id=obj0.id) - self.assertEqual(obj.service, testmodels.Service.database_design) - with self.assertRaises(ValueError): - await testmodels.EnumFields.filter(id=obj0.id).update(service=4) - - async def test_values(self): - obj0 = await testmodels.EnumFields.create(service=testmodels.Service.system_administration) - values = await testmodels.EnumFields.get(id=obj0.id).values("service") - self.assertEqual(values["service"], testmodels.Service.system_administration) - - obj1 = await testmodels.EnumFields.create(service=3) - values = await testmodels.EnumFields.get(id=obj1.id).values("service") - self.assertEqual(values["service"], testmodels.Service.system_administration) - - async def test_values_list(self): - obj0 = await testmodels.EnumFields.create(service=testmodels.Service.system_administration) - values = await testmodels.EnumFields.get(id=obj0.id).values_list("service", flat=True) - self.assertEqual(values, testmodels.Service.system_administration) - - obj1 = await testmodels.EnumFields.create(service=3) - values = await testmodels.EnumFields.get(id=obj1.id).values_list("service", flat=True) - self.assertEqual(values, testmodels.Service.system_administration) - - def 
test_char_fails(self): - with self.assertRaisesRegex( - ConfigurationError, "IntEnumField only supports integer enums!" - ): - IntEnumField(testmodels.Currency) - - def test_range1_fails(self): - with self.assertRaisesRegex( - ConfigurationError, "The valid range of IntEnumField's values is -32768..32767!" - ): - IntEnumField(BadIntEnum1) - - def test_range2_fails(self): - with self.assertRaisesRegex( - ConfigurationError, "The valid range of IntEnumField's values is -32768..32767!" - ): - IntEnumField(BadIntEnum2) - - def test_range3_generated_fails(self): - with self.assertRaisesRegex( - ConfigurationError, "The valid range of IntEnumField's values is 1..32767!" - ): - IntEnumField(BadIntEnumIfGenerated, generated=True) - - def test_range3_manual(self): - fld = IntEnumField(BadIntEnumIfGenerated) - self.assertIs(fld.enum_type, BadIntEnumIfGenerated) - - def test_auto_description(self): - fld = IntEnumField(testmodels.Service) - self.assertEqual( - fld.description, "python_programming: 1\ndatabase_design: 2\nsystem_administration: 3" - ) +# ============================================================================ +# TestIntEnumFields +# ============================================================================ - def test_manual_description(self): - fld = IntEnumField(testmodels.Service, description="foo") - self.assertEqual(fld.description, "foo") +@pytest.mark.asyncio +async def test_int_enum_empty(db): + with pytest.raises(IntegrityError): + await testmodels.EnumFields.create() -class TestCharEnumFields(test.TestCase): - async def test_create(self): - obj0 = await testmodels.EnumFields.create(service=testmodels.Service.system_administration) - self.assertIsInstance(obj0.currency, testmodels.Currency) - obj = await testmodels.EnumFields.get(id=obj0.id) - self.assertIsInstance(obj.currency, testmodels.Currency) - self.assertEqual(obj.currency, testmodels.Currency.HUF) - await obj.save() - obj2 = await testmodels.EnumFields.get(id=obj.id) - 
self.assertEqual(obj, obj2) - await obj.delete() - obj = await testmodels.EnumFields.filter(id=obj0.id).first() - self.assertEqual(obj, None) +@pytest.mark.asyncio +async def test_int_enum_create(db): + obj0 = await testmodels.EnumFields.create(service=testmodels.Service.system_administration) + assert isinstance(obj0.service, testmodels.Service) + obj = await testmodels.EnumFields.get(id=obj0.id) + assert isinstance(obj.service, testmodels.Service) + assert obj.service == testmodels.Service.system_administration + await obj.save() + obj2 = await testmodels.EnumFields.get(id=obj.id) + assert obj == obj2 + + await obj.delete() + obj = await testmodels.EnumFields.filter(id=obj0.id).first() + assert obj is None + + obj3 = await testmodels.EnumFields.create(service=3) + assert isinstance(obj3.service, testmodels.Service) + with pytest.raises(ValueError): + await testmodels.EnumFields.create(service=4) + + +@pytest.mark.asyncio +async def test_int_enum_update(db): + obj0 = await testmodels.EnumFields.create(service=testmodels.Service.system_administration) + await testmodels.EnumFields.filter(id=obj0.id).update( + service=testmodels.Service.database_design + ) + obj = await testmodels.EnumFields.get(id=obj0.id) + assert obj.service == testmodels.Service.database_design + + await testmodels.EnumFields.filter(id=obj0.id).update(service=2) + obj = await testmodels.EnumFields.get(id=obj0.id) + assert obj.service == testmodels.Service.database_design + with pytest.raises(ValueError): + await testmodels.EnumFields.filter(id=obj0.id).update(service=4) + + +@pytest.mark.asyncio +async def test_int_enum_values(db): + obj0 = await testmodels.EnumFields.create(service=testmodels.Service.system_administration) + values = await testmodels.EnumFields.get(id=obj0.id).values("service") + assert values["service"] == testmodels.Service.system_administration + + obj1 = await testmodels.EnumFields.create(service=3) + values = await testmodels.EnumFields.get(id=obj1.id).values("service") + 
assert values["service"] == testmodels.Service.system_administration + + +@pytest.mark.asyncio +async def test_int_enum_values_list(db): + obj0 = await testmodels.EnumFields.create(service=testmodels.Service.system_administration) + values = await testmodels.EnumFields.get(id=obj0.id).values_list("service", flat=True) + assert values == testmodels.Service.system_administration + + obj1 = await testmodels.EnumFields.create(service=3) + values = await testmodels.EnumFields.get(id=obj1.id).values_list("service", flat=True) + assert values == testmodels.Service.system_administration + + +def test_int_enum_char_fails(): + with pytest.raises(ConfigurationError, match="IntEnumField only supports integer enums!"): + IntEnumField(testmodels.Currency) + + +def test_int_enum_range1_fails(): + with pytest.raises( + ConfigurationError, match="The valid range of IntEnumField's values is -32768..32767!" + ): + IntEnumField(BadIntEnum1) + + +def test_int_enum_range2_fails(): + with pytest.raises( + ConfigurationError, match="The valid range of IntEnumField's values is -32768..32767!" 
+ ): + IntEnumField(BadIntEnum2) - obj0 = await testmodels.EnumFields.create( - service=testmodels.Service.system_administration, currency="USD" - ) - self.assertIsInstance(obj0.currency, testmodels.Currency) - with self.assertRaises(ValueError): - await testmodels.EnumFields.create( - service=testmodels.Service.system_administration, currency="XXX" - ) - - async def test_update(self): - obj0 = await testmodels.EnumFields.create( - service=testmodels.Service.system_administration, currency=testmodels.Currency.HUF - ) - await testmodels.EnumFields.filter(id=obj0.id).update(currency=testmodels.Currency.EUR) - obj = await testmodels.EnumFields.get(id=obj0.id) - self.assertEqual(obj.currency, testmodels.Currency.EUR) - - await testmodels.EnumFields.filter(id=obj0.id).update(currency="USD") - obj = await testmodels.EnumFields.get(id=obj0.id) - self.assertEqual(obj.currency, testmodels.Currency.USD) - with self.assertRaises(ValueError): - await testmodels.EnumFields.filter(id=obj0.id).update(currency="XXX") - - async def test_values(self): - obj0 = await testmodels.EnumFields.create( - service=testmodels.Service.system_administration, currency=testmodels.Currency.EUR - ) - values = await testmodels.EnumFields.get(id=obj0.id).values("currency") - self.assertEqual(values["currency"], testmodels.Currency.EUR) - obj1 = await testmodels.EnumFields.create(service=3, currency="EUR") - values = await testmodels.EnumFields.get(id=obj1.id).values("currency") - self.assertEqual(values["currency"], testmodels.Currency.EUR) +def test_int_enum_range3_generated_fails(): + with pytest.raises( + ConfigurationError, match="The valid range of IntEnumField's values is 1..32767!" 
+ ): + IntEnumField(BadIntEnumIfGenerated, generated=True) - async def test_values_list(self): - obj0 = await testmodels.EnumFields.create( - service=testmodels.Service.system_administration, currency=testmodels.Currency.EUR + +def test_int_enum_range3_manual(): + fld = IntEnumField(BadIntEnumIfGenerated) + assert fld.enum_type is BadIntEnumIfGenerated + + +def test_int_enum_auto_description(): + fld = IntEnumField(testmodels.Service) + assert fld.description == "python_programming: 1\ndatabase_design: 2\nsystem_administration: 3" + + +def test_int_enum_manual_description(): + fld = IntEnumField(testmodels.Service, description="foo") + assert fld.description == "foo" + + +# ============================================================================ +# TestCharEnumFields +# ============================================================================ + + +@pytest.mark.asyncio +async def test_char_enum_create(db): + obj0 = await testmodels.EnumFields.create(service=testmodels.Service.system_administration) + assert isinstance(obj0.currency, testmodels.Currency) + obj = await testmodels.EnumFields.get(id=obj0.id) + assert isinstance(obj.currency, testmodels.Currency) + assert obj.currency == testmodels.Currency.HUF + await obj.save() + obj2 = await testmodels.EnumFields.get(id=obj.id) + assert obj == obj2 + + await obj.delete() + obj = await testmodels.EnumFields.filter(id=obj0.id).first() + assert obj is None + + obj0 = await testmodels.EnumFields.create( + service=testmodels.Service.system_administration, currency="USD" + ) + assert isinstance(obj0.currency, testmodels.Currency) + with pytest.raises(ValueError): + await testmodels.EnumFields.create( + service=testmodels.Service.system_administration, currency="XXX" ) - values = await testmodels.EnumFields.get(id=obj0.id).values_list("currency", flat=True) - self.assertEqual(values, testmodels.Currency.EUR) - obj1 = await testmodels.EnumFields.create(service=3, currency="EUR") - values = await 
testmodels.EnumFields.get(id=obj1.id).values_list("currency", flat=True) - self.assertEqual(values, testmodels.Currency.EUR) - def test_auto_maxlen(self): - fld = CharEnumField(testmodels.Currency) - self.assertEqual(fld.max_length, 3) +@pytest.mark.asyncio +async def test_char_enum_update(db): + obj0 = await testmodels.EnumFields.create( + service=testmodels.Service.system_administration, currency=testmodels.Currency.HUF + ) + await testmodels.EnumFields.filter(id=obj0.id).update(currency=testmodels.Currency.EUR) + obj = await testmodels.EnumFields.get(id=obj0.id) + assert obj.currency == testmodels.Currency.EUR + + await testmodels.EnumFields.filter(id=obj0.id).update(currency="USD") + obj = await testmodels.EnumFields.get(id=obj0.id) + assert obj.currency == testmodels.Currency.USD + with pytest.raises(ValueError): + await testmodels.EnumFields.filter(id=obj0.id).update(currency="XXX") + + +@pytest.mark.asyncio +async def test_char_enum_values(db): + obj0 = await testmodels.EnumFields.create( + service=testmodels.Service.system_administration, currency=testmodels.Currency.EUR + ) + values = await testmodels.EnumFields.get(id=obj0.id).values("currency") + assert values["currency"] == testmodels.Currency.EUR + + obj1 = await testmodels.EnumFields.create(service=3, currency="EUR") + values = await testmodels.EnumFields.get(id=obj1.id).values("currency") + assert values["currency"] == testmodels.Currency.EUR + + +@pytest.mark.asyncio +async def test_char_enum_values_list(db): + obj0 = await testmodels.EnumFields.create( + service=testmodels.Service.system_administration, currency=testmodels.Currency.EUR + ) + values = await testmodels.EnumFields.get(id=obj0.id).values_list("currency", flat=True) + assert values == testmodels.Currency.EUR + + obj1 = await testmodels.EnumFields.create(service=3, currency="EUR") + values = await testmodels.EnumFields.get(id=obj1.id).values_list("currency", flat=True) + assert values == testmodels.Currency.EUR + + +def 
test_char_enum_auto_maxlen(): + fld = CharEnumField(testmodels.Currency) + assert fld.max_length == 3 + + +def test_char_enum_defined_maxlen(): + fld = CharEnumField(testmodels.Currency, max_length=5) + assert fld.max_length == 5 + - def test_defined_maxlen(self): - fld = CharEnumField(testmodels.Currency, max_length=5) - self.assertEqual(fld.max_length, 5) +def test_char_enum_auto_description(): + fld = CharEnumField(testmodels.Currency) + assert fld.description == "HUF: HUF\nEUR: EUR\nUSD: USD" - def test_auto_description(self): - fld = CharEnumField(testmodels.Currency) - self.assertEqual(fld.description, "HUF: HUF\nEUR: EUR\nUSD: USD") - def test_manual_description(self): - fld = CharEnumField(testmodels.Currency, description="baa") - self.assertEqual(fld.description, "baa") +def test_char_enum_manual_description(): + fld = CharEnumField(testmodels.Currency, description="baa") + assert fld.description == "baa" diff --git a/tests/fields/test_fk.py b/tests/fields/test_fk.py index c50657229..c07be5b5e 100644 --- a/tests/fields/test_fk.py +++ b/tests/fields/test_fk.py @@ -1,5 +1,6 @@ +import pytest + from tests import testmodels -from tortoise.contrib import test from tortoise.exceptions import ( IntegrityError, NoValuesFetched, @@ -9,292 +10,366 @@ from tortoise.queryset import QuerySet -class TestForeignKeyField(test.TestCase): - def assertRaisesWrongTypeException(self, relation_name: str): - return self.assertRaisesRegex( - ValidationError, f"Invalid type for relationship field '{relation_name}'" - ) +def assert_raises_wrong_type_exception(relation_name: str): + """Context manager that asserts ValidationError with wrong type message.""" + return pytest.raises( + ValidationError, match=f"Invalid type for relationship field '{relation_name}'" + ) + + +@pytest.mark.asyncio +async def test_empty(db): + with pytest.raises(IntegrityError): + await testmodels.MinRelation.create() + + +@pytest.mark.asyncio +async def test_minimal__create_by_id(db): + tour = await 
testmodels.Tournament.create(name="Team1") + rel = await testmodels.MinRelation.create(tournament_id=tour.id) + assert rel.tournament_id == tour.id + assert (await tour.minrelations.all())[0] == rel + + +@pytest.mark.asyncio +async def test_minimal__create_by_name(db): + tour = await testmodels.Tournament.create(name="Team1") + rel = await testmodels.MinRelation.create(tournament=tour) + await rel.fetch_related("tournament") + assert rel.tournament == tour + assert (await tour.minrelations.all())[0] == rel + + +@pytest.mark.asyncio +async def test_minimal__by_name__created_prefetched(db): + tour = await testmodels.Tournament.create(name="Team1") + rel = await testmodels.MinRelation.create(tournament=tour) + assert rel.tournament == tour + assert (await tour.minrelations.all())[0] == rel + + +@pytest.mark.asyncio +async def test_minimal__by_name__unfetched(db): + tour = await testmodels.Tournament.create(name="Team1") + rel = await testmodels.MinRelation.create(tournament=tour) + rel = await testmodels.MinRelation.get(id=rel.id) + assert isinstance(rel.tournament, QuerySet) + + +@pytest.mark.asyncio +async def test_minimal__by_name__re_awaited(db): + tour = await testmodels.Tournament.create(name="Team1") + rel = await testmodels.MinRelation.create(tournament=tour) + await rel.fetch_related("tournament") + assert rel.tournament == tour + assert await rel.tournament == tour + + +@pytest.mark.asyncio +async def test_minimal__by_name__awaited(db): + tour = await testmodels.Tournament.create(name="Team1") + rel = await testmodels.MinRelation.create(tournament=tour) + rel = await testmodels.MinRelation.get(id=rel.id) + assert await rel.tournament == tour + assert (await tour.minrelations.all())[0] == rel + + +@pytest.mark.asyncio +async def test_event__create_by_id(db): + tour = await testmodels.Tournament.create(name="Team1") + rel = await testmodels.Event.create(name="Event1", tournament_id=tour.id) + assert rel.tournament_id == tour.id + assert (await 
tour.events.all())[0] == rel + + +@pytest.mark.asyncio +async def test_event__create_by_name(db): + tour = await testmodels.Tournament.create(name="Team1") + rel = await testmodels.Event.create(name="Event1", tournament=tour) + await rel.fetch_related("tournament") + assert rel.tournament == tour + assert (await tour.events.all())[0] == rel - async def test_empty(self): - with self.assertRaises(IntegrityError): - await testmodels.MinRelation.create() - - async def test_minimal__create_by_id(self): - tour = await testmodels.Tournament.create(name="Team1") - rel = await testmodels.MinRelation.create(tournament_id=tour.id) - self.assertEqual(rel.tournament_id, tour.id) - self.assertEqual((await tour.minrelations.all())[0], rel) - - async def test_minimal__create_by_name(self): - tour = await testmodels.Tournament.create(name="Team1") - rel = await testmodels.MinRelation.create(tournament=tour) - await rel.fetch_related("tournament") - self.assertEqual(rel.tournament, tour) - self.assertEqual((await tour.minrelations.all())[0], rel) - - async def test_minimal__by_name__created_prefetched(self): - tour = await testmodels.Tournament.create(name="Team1") - rel = await testmodels.MinRelation.create(tournament=tour) - self.assertEqual(rel.tournament, tour) - self.assertEqual((await tour.minrelations.all())[0], rel) - - async def test_minimal__by_name__unfetched(self): - tour = await testmodels.Tournament.create(name="Team1") - rel = await testmodels.MinRelation.create(tournament=tour) - rel = await testmodels.MinRelation.get(id=rel.id) - self.assertIsInstance(rel.tournament, QuerySet) - - async def test_minimal__by_name__re_awaited(self): - tour = await testmodels.Tournament.create(name="Team1") - rel = await testmodels.MinRelation.create(tournament=tour) - await rel.fetch_related("tournament") - self.assertEqual(rel.tournament, tour) - self.assertEqual(await rel.tournament, tour) - - async def test_minimal__by_name__awaited(self): - tour = await 
testmodels.Tournament.create(name="Team1") - rel = await testmodels.MinRelation.create(tournament=tour) - rel = await testmodels.MinRelation.get(id=rel.id) - self.assertEqual(await rel.tournament, tour) - self.assertEqual((await tour.minrelations.all())[0], rel) - - async def test_event__create_by_id(self): - tour = await testmodels.Tournament.create(name="Team1") - rel = await testmodels.Event.create(name="Event1", tournament_id=tour.id) - self.assertEqual(rel.tournament_id, tour.id) - self.assertEqual((await tour.events.all())[0], rel) - - async def test_event__create_by_name(self): - tour = await testmodels.Tournament.create(name="Team1") - rel = await testmodels.Event.create(name="Event1", tournament=tour) - await rel.fetch_related("tournament") - self.assertEqual(rel.tournament, tour) - self.assertEqual((await tour.events.all())[0], rel) - - async def test_update_by_name(self): - tour = await testmodels.Tournament.create(name="Team1") - tour2 = await testmodels.Tournament.create(name="Team2") - rel0 = await testmodels.Event.create(name="Event1", tournament=tour) - - await testmodels.Event.filter(pk=rel0.pk).update(tournament=tour2) - rel = await testmodels.Event.get(event_id=rel0.event_id) - - await rel.fetch_related("tournament") - self.assertEqual(rel.tournament, tour2) - self.assertEqual(await tour.events.all(), []) - self.assertEqual((await tour2.events.all())[0], rel) - - async def test_update_by_id(self): - tour = await testmodels.Tournament.create(name="Team1") - tour2 = await testmodels.Tournament.create(name="Team2") - rel0 = await testmodels.Event.create(name="Event1", tournament_id=tour.id) - - await testmodels.Event.filter(event_id=rel0.event_id).update(tournament_id=tour2.id) - rel = await testmodels.Event.get(pk=rel0.pk) - - self.assertEqual(rel.tournament_id, tour2.id) - self.assertEqual(await tour.events.all(), []) - self.assertEqual((await tour2.events.all())[0], rel) - - async def test_minimal__uninstantiated_create(self): - tour = 
testmodels.Tournament(name="Team1") - with self.assertRaisesRegex(OperationalError, "You should first call .save()"): - await testmodels.MinRelation.create(tournament=tour) - - async def test_minimal__uninstantiated_iterate(self): - tour = testmodels.Tournament(name="Team1") - with self.assertRaisesRegex( - OperationalError, "This objects hasn't been instanced, call .save()" - ): - async for _ in tour.minrelations: - pass - - async def test_minimal__uninstantiated_await(self): - tour = testmodels.Tournament(name="Team1") - with self.assertRaisesRegex( - OperationalError, "This objects hasn't been instanced, call .save()" - ): - await tour.minrelations - - async def test_minimal__unfetched_contains(self): - tour = await testmodels.Tournament.create(name="Team1") - with self.assertRaisesRegex( - NoValuesFetched, - "No values were fetched for this relation, first use .fetch_related()", - ): - "a" in tour.minrelations # pylint: disable=W0104 - - async def test_minimal__unfetched_iter(self): - tour = await testmodels.Tournament.create(name="Team1") - with self.assertRaisesRegex( - NoValuesFetched, - "No values were fetched for this relation, first use .fetch_related()", - ): - for _ in tour.minrelations: - pass - - async def test_minimal__unfetched_len(self): - tour = await testmodels.Tournament.create(name="Team1") - with self.assertRaisesRegex( - NoValuesFetched, - "No values were fetched for this relation, first use .fetch_related()", - ): - len(tour.minrelations) - - async def test_minimal__unfetched_bool(self): - tour = await testmodels.Tournament.create(name="Team1") - with self.assertRaisesRegex( - NoValuesFetched, - "No values were fetched for this relation, first use .fetch_related()", - ): - bool(tour.minrelations) - - async def test_minimal__unfetched_getitem(self): - tour = await testmodels.Tournament.create(name="Team1") - with self.assertRaisesRegex( - NoValuesFetched, - "No values were fetched for this relation, first use .fetch_related()", - ): - 
tour.minrelations[0] # pylint: disable=W0104 - - async def test_minimal__instantiated_create(self): - tour = await testmodels.Tournament.create(name="Team1") + +@pytest.mark.asyncio +async def test_update_by_name(db): + tour = await testmodels.Tournament.create(name="Team1") + tour2 = await testmodels.Tournament.create(name="Team2") + rel0 = await testmodels.Event.create(name="Event1", tournament=tour) + + await testmodels.Event.filter(pk=rel0.pk).update(tournament=tour2) + rel = await testmodels.Event.get(event_id=rel0.event_id) + + await rel.fetch_related("tournament") + assert rel.tournament == tour2 + assert await tour.events.all() == [] + assert (await tour2.events.all())[0] == rel + + +@pytest.mark.asyncio +async def test_update_by_id(db): + tour = await testmodels.Tournament.create(name="Team1") + tour2 = await testmodels.Tournament.create(name="Team2") + rel0 = await testmodels.Event.create(name="Event1", tournament_id=tour.id) + + await testmodels.Event.filter(event_id=rel0.event_id).update(tournament_id=tour2.id) + rel = await testmodels.Event.get(pk=rel0.pk) + + assert rel.tournament_id == tour2.id + assert await tour.events.all() == [] + assert (await tour2.events.all())[0] == rel + + +@pytest.mark.asyncio +async def test_minimal__uninstantiated_create(db): + tour = testmodels.Tournament(name="Team1") + with pytest.raises(OperationalError, match="You should first call .save()"): await testmodels.MinRelation.create(tournament=tour) - async def test_minimal__instantiated_create_wrong_type(self): - author = await testmodels.Author.create(name="Author1") - with self.assertRaisesWrongTypeException("tournament"): - await testmodels.MinRelation.create(tournament=author) - async def test_minimal__instantiated_iterate(self): - tour = await testmodels.Tournament.create(name="Team1") +@pytest.mark.asyncio +async def test_minimal__uninstantiated_iterate(db): + tour = testmodels.Tournament(name="Team1") + with pytest.raises(OperationalError, match="This objects 
hasn't been instanced, call .save()"): async for _ in tour.minrelations: pass - async def test_minimal__instantiated_await(self): - tour = await testmodels.Tournament.create(name="Team1") + +@pytest.mark.asyncio +async def test_minimal__uninstantiated_await(db): + tour = testmodels.Tournament(name="Team1") + with pytest.raises(OperationalError, match="This objects hasn't been instanced, call .save()"): await tour.minrelations - async def test_minimal__fetched_contains(self): - tour = await testmodels.Tournament.create(name="Team1") - rel = await testmodels.MinRelation.create(tournament=tour) - await tour.fetch_related("minrelations") - self.assertTrue(rel in tour.minrelations) - async def test_minimal__fetched_iter(self): - tour = await testmodels.Tournament.create(name="Team1") - rel = await testmodels.MinRelation.create(tournament=tour) - await tour.fetch_related("minrelations") - self.assertEqual(list(tour.minrelations), [rel]) +@pytest.mark.asyncio +async def test_minimal__unfetched_contains(db): + tour = await testmodels.Tournament.create(name="Team1") + with pytest.raises( + NoValuesFetched, + match="No values were fetched for this relation, first use .fetch_related()", + ): + "a" in tour.minrelations # pylint: disable=W0104 + + +@pytest.mark.asyncio +async def test_minimal__unfetched_iter(db): + tour = await testmodels.Tournament.create(name="Team1") + with pytest.raises( + NoValuesFetched, + match="No values were fetched for this relation, first use .fetch_related()", + ): + for _ in tour.minrelations: + pass - async def test_minimal__fetched_len(self): - tour = await testmodels.Tournament.create(name="Team1") - await testmodels.MinRelation.create(tournament=tour) - await tour.fetch_related("minrelations") - self.assertEqual(len(tour.minrelations), 1) - async def test_minimal__fetched_bool(self): - tour = await testmodels.Tournament.create(name="Team1") - await tour.fetch_related("minrelations") - self.assertFalse(bool(tour.minrelations)) - await 
testmodels.MinRelation.create(tournament=tour) - await tour.fetch_related("minrelations") - self.assertTrue(bool(tour.minrelations)) - - async def test_minimal__fetched_getitem(self): - tour = await testmodels.Tournament.create(name="Team1") - rel = await testmodels.MinRelation.create(tournament=tour) - await tour.fetch_related("minrelations") - self.assertEqual(tour.minrelations[0], rel) - - with self.assertRaises(IndexError): - tour.minrelations[1] # pylint: disable=W0104 - - async def test_event__filter(self): - tour = await testmodels.Tournament.create(name="Team1") - event1 = await testmodels.Event.create(name="Event1", tournament=tour) - event2 = await testmodels.Event.create(name="Event2", tournament=tour) - self.assertEqual(await tour.events.filter(name="Event1"), [event1]) - self.assertEqual(await tour.events.filter(name="Event2"), [event2]) - self.assertEqual(await tour.events.filter(name="Event3"), []) - - async def test_event__all(self): - tour = await testmodels.Tournament.create(name="Team1") - event1 = await testmodels.Event.create(name="Event1", tournament=tour) - event2 = await testmodels.Event.create(name="Event2", tournament=tour) - self.assertSetEqual(set(await tour.events.all()), {event1, event2}) - - async def test_event__order_by(self): - tour = await testmodels.Tournament.create(name="Team1") - event1 = await testmodels.Event.create(name="Event1", tournament=tour) - event2 = await testmodels.Event.create(name="Event2", tournament=tour) - self.assertEqual(await tour.events.order_by("-name"), [event2, event1]) - self.assertEqual(await tour.events.order_by("name"), [event1, event2]) - - async def test_event__limit(self): - tour = await testmodels.Tournament.create(name="Team1") - event1 = await testmodels.Event.create(name="Event1", tournament=tour) - event2 = await testmodels.Event.create(name="Event2", tournament=tour) - await testmodels.Event.create(name="Event3", tournament=tour) - self.assertEqual(await 
tour.events.limit(2).order_by("name"), [event1, event2]) - - async def test_event__offset(self): - tour = await testmodels.Tournament.create(name="Team1") - await testmodels.Event.create(name="Event1", tournament=tour) - event2 = await testmodels.Event.create(name="Event2", tournament=tour) - event3 = await testmodels.Event.create(name="Event3", tournament=tour) - self.assertEqual(await tour.events.offset(1).order_by("name"), [event2, event3]) - - async def test_fk_correct_type_assignment(self): - tour1 = await testmodels.Tournament.create(name="Team1") - tour2 = await testmodels.Tournament.create(name="Team2") - event = await testmodels.Event(name="Event1", tournament=tour1) - - event.tournament = tour2 - await event.save() - self.assertEqual(event.tournament_id, tour2.id) - - async def test_fk_wrong_type_assignment(self): - tour = await testmodels.Tournament.create(name="Team1") - author = await testmodels.Author.create(name="Author") - rel = await testmodels.MinRelation.create(tournament=tour) - - with self.assertRaisesWrongTypeException("tournament"): - rel.tournament = author - - async def test_fk_none_assignment(self): - manager = await testmodels.Employee.create(name="Manager") - employee = await testmodels.Employee.create(name="Employee", manager=manager) - - employee.manager = None - await employee.save() - self.assertIsNone(employee.manager) - - async def test_fk_update_wrong_type(self): - tour = await testmodels.Tournament.create(name="Team1") - rel = await testmodels.MinRelation.create(tournament=tour) - author = await testmodels.Author.create(name="Author1") - - with self.assertRaisesWrongTypeException("tournament"): - await testmodels.MinRelation.filter(id=rel.id).update(tournament=author) - - async def test_fk_bulk_create_wrong_type(self): - author = await testmodels.Author.create(name="Author") - with self.assertRaisesWrongTypeException("tournament"): - await testmodels.MinRelation.bulk_create( - [testmodels.MinRelation(tournament=author) for _ in 
range(10)] - ) - - async def test_fk_bulk_update_wrong_type(self): - tour = await testmodels.Tournament.create(name="Team1") +@pytest.mark.asyncio +async def test_minimal__unfetched_len(db): + tour = await testmodels.Tournament.create(name="Team1") + with pytest.raises( + NoValuesFetched, + match="No values were fetched for this relation, first use .fetch_related()", + ): + len(tour.minrelations) + + +@pytest.mark.asyncio +async def test_minimal__unfetched_bool(db): + tour = await testmodels.Tournament.create(name="Team1") + with pytest.raises( + NoValuesFetched, + match="No values were fetched for this relation, first use .fetch_related()", + ): + bool(tour.minrelations) + + +@pytest.mark.asyncio +async def test_minimal__unfetched_getitem(db): + tour = await testmodels.Tournament.create(name="Team1") + with pytest.raises( + NoValuesFetched, + match="No values were fetched for this relation, first use .fetch_related()", + ): + tour.minrelations[0] # pylint: disable=W0104 + + +@pytest.mark.asyncio +async def test_minimal__instantiated_create(db): + tour = await testmodels.Tournament.create(name="Team1") + await testmodels.MinRelation.create(tournament=tour) + + +@pytest.mark.asyncio +async def test_minimal__instantiated_create_wrong_type(db): + author = await testmodels.Author.create(name="Author1") + with assert_raises_wrong_type_exception("tournament"): + await testmodels.MinRelation.create(tournament=author) + + +@pytest.mark.asyncio +async def test_minimal__instantiated_iterate(db): + tour = await testmodels.Tournament.create(name="Team1") + async for _ in tour.minrelations: + pass + + +@pytest.mark.asyncio +async def test_minimal__instantiated_await(db): + tour = await testmodels.Tournament.create(name="Team1") + await tour.minrelations + + +@pytest.mark.asyncio +async def test_minimal__fetched_contains(db): + tour = await testmodels.Tournament.create(name="Team1") + rel = await testmodels.MinRelation.create(tournament=tour) + await 
tour.fetch_related("minrelations") + assert rel in tour.minrelations + + +@pytest.mark.asyncio +async def test_minimal__fetched_iter(db): + tour = await testmodels.Tournament.create(name="Team1") + rel = await testmodels.MinRelation.create(tournament=tour) + await tour.fetch_related("minrelations") + assert list(tour.minrelations) == [rel] + + +@pytest.mark.asyncio +async def test_minimal__fetched_len(db): + tour = await testmodels.Tournament.create(name="Team1") + await testmodels.MinRelation.create(tournament=tour) + await tour.fetch_related("minrelations") + assert len(tour.minrelations) == 1 + + +@pytest.mark.asyncio +async def test_minimal__fetched_bool(db): + tour = await testmodels.Tournament.create(name="Team1") + await tour.fetch_related("minrelations") + assert not bool(tour.minrelations) + await testmodels.MinRelation.create(tournament=tour) + await tour.fetch_related("minrelations") + assert bool(tour.minrelations) + + +@pytest.mark.asyncio +async def test_minimal__fetched_getitem(db): + tour = await testmodels.Tournament.create(name="Team1") + rel = await testmodels.MinRelation.create(tournament=tour) + await tour.fetch_related("minrelations") + assert tour.minrelations[0] == rel + + with pytest.raises(IndexError): + tour.minrelations[1] # pylint: disable=W0104 + + +@pytest.mark.asyncio +async def test_event__filter(db): + tour = await testmodels.Tournament.create(name="Team1") + event1 = await testmodels.Event.create(name="Event1", tournament=tour) + event2 = await testmodels.Event.create(name="Event2", tournament=tour) + assert await tour.events.filter(name="Event1") == [event1] + assert await tour.events.filter(name="Event2") == [event2] + assert await tour.events.filter(name="Event3") == [] + + +@pytest.mark.asyncio +async def test_event__all(db): + tour = await testmodels.Tournament.create(name="Team1") + event1 = await testmodels.Event.create(name="Event1", tournament=tour) + event2 = await testmodels.Event.create(name="Event2", tournament=tour) 
+ assert set(await tour.events.all()) == {event1, event2} + + +@pytest.mark.asyncio +async def test_event__order_by(db): + tour = await testmodels.Tournament.create(name="Team1") + event1 = await testmodels.Event.create(name="Event1", tournament=tour) + event2 = await testmodels.Event.create(name="Event2", tournament=tour) + assert await tour.events.order_by("-name") == [event2, event1] + assert await tour.events.order_by("name") == [event1, event2] + + +@pytest.mark.asyncio +async def test_event__limit(db): + tour = await testmodels.Tournament.create(name="Team1") + event1 = await testmodels.Event.create(name="Event1", tournament=tour) + event2 = await testmodels.Event.create(name="Event2", tournament=tour) + await testmodels.Event.create(name="Event3", tournament=tour) + assert await tour.events.limit(2).order_by("name") == [event1, event2] + + +@pytest.mark.asyncio +async def test_event__offset(db): + tour = await testmodels.Tournament.create(name="Team1") + await testmodels.Event.create(name="Event1", tournament=tour) + event2 = await testmodels.Event.create(name="Event2", tournament=tour) + event3 = await testmodels.Event.create(name="Event3", tournament=tour) + assert await tour.events.offset(1).order_by("name") == [event2, event3] + + +@pytest.mark.asyncio +async def test_fk_correct_type_assignment(db): + tour1 = await testmodels.Tournament.create(name="Team1") + tour2 = await testmodels.Tournament.create(name="Team2") + event = await testmodels.Event(name="Event1", tournament=tour1) + + event.tournament = tour2 + await event.save() + assert event.tournament_id == tour2.id + + +@pytest.mark.asyncio +async def test_fk_wrong_type_assignment(db): + tour = await testmodels.Tournament.create(name="Team1") + author = await testmodels.Author.create(name="Author") + rel = await testmodels.MinRelation.create(tournament=tour) + + with assert_raises_wrong_type_exception("tournament"): + rel.tournament = author + + +@pytest.mark.asyncio +async def 
test_fk_none_assignment(db): + manager = await testmodels.Employee.create(name="Manager") + employee = await testmodels.Employee.create(name="Employee", manager=manager) + + employee.manager = None + await employee.save() + assert employee.manager is None + + +@pytest.mark.asyncio +async def test_fk_update_wrong_type(db): + tour = await testmodels.Tournament.create(name="Team1") + rel = await testmodels.MinRelation.create(tournament=tour) + author = await testmodels.Author.create(name="Author1") + + with assert_raises_wrong_type_exception("tournament"): + await testmodels.MinRelation.filter(id=rel.id).update(tournament=author) + + +@pytest.mark.asyncio +async def test_fk_bulk_create_wrong_type(db): + author = await testmodels.Author.create(name="Author") + with assert_raises_wrong_type_exception("tournament"): await testmodels.MinRelation.bulk_create( - [testmodels.MinRelation(tournament=tour) for _ in range(1, 10)] + [testmodels.MinRelation(tournament=author) for _ in range(10)] + ) + + +@pytest.mark.asyncio +async def test_fk_bulk_update_wrong_type(db): + tour = await testmodels.Tournament.create(name="Team1") + await testmodels.MinRelation.bulk_create( + [testmodels.MinRelation(tournament=tour) for _ in range(1, 10)] + ) + author = await testmodels.Author.create(name="Author") + + with assert_raises_wrong_type_exception("tournament"): + relations = await testmodels.MinRelation.all() + await testmodels.MinRelation.bulk_update( + [testmodels.MinRelation(id=rel.id, tournament=author) for rel in relations], + fields=["tournament"], ) - author = await testmodels.Author.create(name="Author") - - with self.assertRaisesWrongTypeException("tournament"): - relations = await testmodels.MinRelation.all() - await testmodels.MinRelation.bulk_update( - [testmodels.MinRelation(id=rel.id, tournament=author) for rel in relations], - fields=["tournament"], - ) diff --git a/tests/fields/test_fk_uuid.py b/tests/fields/test_fk_uuid.py index 31c88ad87..952007824 100644 --- 
a/tests/fields/test_fk_uuid.py +++ b/tests/fields/test_fk_uuid.py @@ -1,306 +1,427 @@ +import pytest + from tests import testmodels -from tortoise.contrib import test from tortoise.exceptions import IntegrityError, NoValuesFetched, OperationalError from tortoise.queryset import QuerySet -class TestForeignKeyUUIDField(test.TestCase): +# Parameterize to test both standard and source-field models +@pytest.fixture( + params=[ + pytest.param( + ( + testmodels.UUIDPkModel, + testmodels.UUIDFkRelatedModel, + testmodels.UUIDFkRelatedNullModel, + ), + id="standard", + ), + pytest.param( + ( + testmodels.UUIDPkSourceModel, + testmodels.UUIDFkRelatedSourceModel, + testmodels.UUIDFkRelatedNullSourceModel, + ), + id="sourced", + ), + ] +) +def uuid_models(request): """ - Here we do the same FK tests but using UUID. The reason this is useful is: + Fixture providing UUID model classes for FK tests. + Tests both standard UUID models and source-field models: * UUID needs escaping, so a good indicator of where we may have missed it. * UUID populates a value BEFORE it gets committed to DB, whereas int is AFTER. * UUID is stored differently for different DB backends. (native in PG) + + The sourced variant tests identical Python-like models with customized DB names, + helping ensure we don't confuse the two concepts. 
""" + return request.param + - UUIDPkModel = testmodels.UUIDPkModel - UUIDFkRelatedModel = testmodels.UUIDFkRelatedModel - UUIDFkRelatedNullModel = testmodels.UUIDFkRelatedNullModel - - async def test_empty(self): - with self.assertRaises(IntegrityError): - await self.UUIDFkRelatedModel.create() - - async def test_empty_null(self): - await self.UUIDFkRelatedNullModel.create() - - async def test_create_by_id(self): - tour = await self.UUIDPkModel.create() - rel = await self.UUIDFkRelatedModel.create(model_id=tour.id) - self.assertEqual(rel.model_id, tour.id) - self.assertEqual((await tour.children.all())[0], rel) - - async def test_create_by_name(self): - tour = await self.UUIDPkModel.create() - rel = await self.UUIDFkRelatedModel.create(model=tour) - await rel.fetch_related("model") - self.assertEqual(rel.model, tour) - self.assertEqual((await tour.children.all())[0], rel) - - async def test_by_name__created_prefetched(self): - tour = await self.UUIDPkModel.create() - rel = await self.UUIDFkRelatedModel.create(model=tour) - self.assertEqual(rel.model, tour) - self.assertEqual((await tour.children.all())[0], rel) - - async def test_by_name__unfetched(self): - tour = await self.UUIDPkModel.create() - rel = await self.UUIDFkRelatedModel.create(model=tour) - rel = await self.UUIDFkRelatedModel.get(id=rel.id) - self.assertIsInstance(rel.model, QuerySet) - - async def test_by_name__re_awaited(self): - tour = await self.UUIDPkModel.create() - rel = await self.UUIDFkRelatedModel.create(model=tour) - await rel.fetch_related("model") - self.assertEqual(rel.model, tour) - self.assertEqual(await rel.model, tour) - - async def test_by_name__awaited(self): - tour = await self.UUIDPkModel.create() - rel = await self.UUIDFkRelatedModel.create(model=tour) - rel = await self.UUIDFkRelatedModel.get(id=rel.id) - self.assertEqual(await rel.model, tour) - self.assertEqual((await tour.children.all())[0], rel) - - async def test_update_by_name(self): - tour = await 
self.UUIDPkModel.create() - tour2 = await self.UUIDPkModel.create() - rel0 = await self.UUIDFkRelatedModel.create(model=tour) - - await self.UUIDFkRelatedModel.filter(id=rel0.id).update(model=tour2) - rel = await self.UUIDFkRelatedModel.get(id=rel0.id) - - await rel.fetch_related("model") - self.assertEqual(rel.model, tour2) - self.assertEqual(await tour.children.all(), []) - self.assertEqual((await tour2.children.all())[0], rel) - - async def test_update_by_id(self): - tour = await self.UUIDPkModel.create() - tour2 = await self.UUIDPkModel.create() - rel0 = await self.UUIDFkRelatedModel.create(model_id=tour.id) - - await self.UUIDFkRelatedModel.filter(id=rel0.id).update(model_id=tour2.id) - rel = await self.UUIDFkRelatedModel.get(id=rel0.id) - - self.assertEqual(rel.model_id, tour2.id) - self.assertEqual(await tour.children.all(), []) - self.assertEqual((await tour2.children.all())[0], rel) - - async def test_uninstantiated_create(self): - tour = self.UUIDPkModel() - with self.assertRaisesRegex(OperationalError, "You should first call .save()"): - await self.UUIDFkRelatedModel.create(model=tour) - - async def test_uninstantiated_iterate(self): - tour = self.UUIDPkModel() - with self.assertRaisesRegex( - OperationalError, "This objects hasn't been instanced, call .save()" - ): - async for _ in tour.children: - pass - - async def test_uninstantiated_await(self): - tour = self.UUIDPkModel() - with self.assertRaisesRegex( - OperationalError, "This objects hasn't been instanced, call .save()" - ): - await tour.children - - async def test_unfetched_contains(self): - tour = await self.UUIDPkModel.create() - with self.assertRaisesRegex( - NoValuesFetched, - "No values were fetched for this relation, first use .fetch_related()", - ): - "a" in tour.children # pylint: disable=W0104 - - async def test_unfetched_iter(self): - tour = await self.UUIDPkModel.create() - with self.assertRaisesRegex( - NoValuesFetched, - "No values were fetched for this relation, first use 
.fetch_related()", - ): - for _ in tour.children: - pass - - async def test_unfetched_len(self): - tour = await self.UUIDPkModel.create() - with self.assertRaisesRegex( - NoValuesFetched, - "No values were fetched for this relation, first use .fetch_related()", - ): - len(tour.children) - - async def test_unfetched_bool(self): - tour = await self.UUIDPkModel.create() - with self.assertRaisesRegex( - NoValuesFetched, - "No values were fetched for this relation, first use .fetch_related()", - ): - bool(tour.children) - - async def test_unfetched_getitem(self): - tour = await self.UUIDPkModel.create() - with self.assertRaisesRegex( - NoValuesFetched, - "No values were fetched for this relation, first use .fetch_related()", - ): - tour.children[0] # pylint: disable=W0104 - - async def test_instantiated_create(self): - tour = await self.UUIDPkModel.create() - await self.UUIDFkRelatedModel.create(model=tour) - - async def test_instantiated_iterate(self): - tour = await self.UUIDPkModel.create() +@pytest.mark.asyncio +async def test_empty(db, uuid_models): + UUIDPkModel, UUIDFkRelatedModel, UUIDFkRelatedNullModel = uuid_models + with pytest.raises(IntegrityError): + await UUIDFkRelatedModel.create() + + +@pytest.mark.asyncio +async def test_empty_null(db, uuid_models): + UUIDPkModel, UUIDFkRelatedModel, UUIDFkRelatedNullModel = uuid_models + await UUIDFkRelatedNullModel.create() + + +@pytest.mark.asyncio +async def test_create_by_id(db, uuid_models): + UUIDPkModel, UUIDFkRelatedModel, UUIDFkRelatedNullModel = uuid_models + tour = await UUIDPkModel.create() + rel = await UUIDFkRelatedModel.create(model_id=tour.id) + assert rel.model_id == tour.id + assert (await tour.children.all())[0] == rel + + +@pytest.mark.asyncio +async def test_create_by_name(db, uuid_models): + UUIDPkModel, UUIDFkRelatedModel, UUIDFkRelatedNullModel = uuid_models + tour = await UUIDPkModel.create() + rel = await UUIDFkRelatedModel.create(model=tour) + await rel.fetch_related("model") + assert 
rel.model == tour + assert (await tour.children.all())[0] == rel + + +@pytest.mark.asyncio +async def test_by_name__created_prefetched(db, uuid_models): + UUIDPkModel, UUIDFkRelatedModel, UUIDFkRelatedNullModel = uuid_models + tour = await UUIDPkModel.create() + rel = await UUIDFkRelatedModel.create(model=tour) + assert rel.model == tour + assert (await tour.children.all())[0] == rel + + +@pytest.mark.asyncio +async def test_by_name__unfetched(db, uuid_models): + UUIDPkModel, UUIDFkRelatedModel, UUIDFkRelatedNullModel = uuid_models + tour = await UUIDPkModel.create() + rel = await UUIDFkRelatedModel.create(model=tour) + rel = await UUIDFkRelatedModel.get(id=rel.id) + assert isinstance(rel.model, QuerySet) + + +@pytest.mark.asyncio +async def test_by_name__re_awaited(db, uuid_models): + UUIDPkModel, UUIDFkRelatedModel, UUIDFkRelatedNullModel = uuid_models + tour = await UUIDPkModel.create() + rel = await UUIDFkRelatedModel.create(model=tour) + await rel.fetch_related("model") + assert rel.model == tour + assert await rel.model == tour + + +@pytest.mark.asyncio +async def test_by_name__awaited(db, uuid_models): + UUIDPkModel, UUIDFkRelatedModel, UUIDFkRelatedNullModel = uuid_models + tour = await UUIDPkModel.create() + rel = await UUIDFkRelatedModel.create(model=tour) + rel = await UUIDFkRelatedModel.get(id=rel.id) + assert await rel.model == tour + assert (await tour.children.all())[0] == rel + + +@pytest.mark.asyncio +async def test_update_by_name(db, uuid_models): + UUIDPkModel, UUIDFkRelatedModel, UUIDFkRelatedNullModel = uuid_models + tour = await UUIDPkModel.create() + tour2 = await UUIDPkModel.create() + rel0 = await UUIDFkRelatedModel.create(model=tour) + + await UUIDFkRelatedModel.filter(id=rel0.id).update(model=tour2) + rel = await UUIDFkRelatedModel.get(id=rel0.id) + + await rel.fetch_related("model") + assert rel.model == tour2 + assert await tour.children.all() == [] + assert (await tour2.children.all())[0] == rel + + +@pytest.mark.asyncio +async def 
test_update_by_id(db, uuid_models): + UUIDPkModel, UUIDFkRelatedModel, UUIDFkRelatedNullModel = uuid_models + tour = await UUIDPkModel.create() + tour2 = await UUIDPkModel.create() + rel0 = await UUIDFkRelatedModel.create(model_id=tour.id) + + await UUIDFkRelatedModel.filter(id=rel0.id).update(model_id=tour2.id) + rel = await UUIDFkRelatedModel.get(id=rel0.id) + + assert rel.model_id == tour2.id + assert await tour.children.all() == [] + assert (await tour2.children.all())[0] == rel + + +@pytest.mark.asyncio +async def test_uninstantiated_create(db, uuid_models): + UUIDPkModel, UUIDFkRelatedModel, UUIDFkRelatedNullModel = uuid_models + tour = UUIDPkModel() + with pytest.raises(OperationalError, match="You should first call .save()"): + await UUIDFkRelatedModel.create(model=tour) + + +@pytest.mark.asyncio +async def test_uninstantiated_iterate(db, uuid_models): + UUIDPkModel, UUIDFkRelatedModel, UUIDFkRelatedNullModel = uuid_models + tour = UUIDPkModel() + with pytest.raises(OperationalError, match="This objects hasn't been instanced, call .save()"): async for _ in tour.children: pass - async def test_instantiated_await(self): - tour = await self.UUIDPkModel.create() + +@pytest.mark.asyncio +async def test_uninstantiated_await(db, uuid_models): + UUIDPkModel, UUIDFkRelatedModel, UUIDFkRelatedNullModel = uuid_models + tour = UUIDPkModel() + with pytest.raises(OperationalError, match="This objects hasn't been instanced, call .save()"): await tour.children - async def test_minimal__fetched_contains(self): - tour = await self.UUIDPkModel.create() - rel = await self.UUIDFkRelatedModel.create(model=tour) - await tour.fetch_related("children") - self.assertTrue(rel in tour.children) - - async def test_minimal__fetched_iter(self): - tour = await self.UUIDPkModel.create() - rel = await self.UUIDFkRelatedModel.create(model=tour) - await tour.fetch_related("children") - self.assertEqual(list(tour.children), [rel]) - - async def test_minimal__fetched_len(self): - tour = await 
self.UUIDPkModel.create() - await self.UUIDFkRelatedModel.create(model=tour) - await tour.fetch_related("children") - self.assertEqual(len(tour.children), 1) - - async def test_minimal__fetched_bool(self): - tour = await self.UUIDPkModel.create() - await tour.fetch_related("children") - self.assertFalse(bool(tour.children)) - await self.UUIDFkRelatedModel.create(model=tour) - await tour.fetch_related("children") - self.assertTrue(bool(tour.children)) - - async def test_minimal__fetched_getitem(self): - tour = await self.UUIDPkModel.create() - rel = await self.UUIDFkRelatedModel.create(model=tour) - await tour.fetch_related("children") - self.assertEqual(tour.children[0], rel) - - with self.assertRaises(IndexError): - tour.children[1] # pylint: disable=W0104 - - async def test_event__filter(self): - tour = await self.UUIDPkModel.create() - event1 = await self.UUIDFkRelatedModel.create(name="Event1", model=tour) - event2 = await self.UUIDFkRelatedModel.create(name="Event2", model=tour) - self.assertEqual(await tour.children.filter(name="Event1"), [event1]) - self.assertEqual(await tour.children.filter(name="Event2"), [event2]) - self.assertEqual(await tour.children.filter(name="Event3"), []) - - async def test_event__all(self): - tour = await self.UUIDPkModel.create() - event1 = await self.UUIDFkRelatedModel.create(name="Event1", model=tour) - event2 = await self.UUIDFkRelatedModel.create(name="Event2", model=tour) - self.assertSetEqual(set(await tour.children.all()), {event1, event2}) - - async def test_event__order_by(self): - tour = await self.UUIDPkModel.create() - event1 = await self.UUIDFkRelatedModel.create(name="Event1", model=tour) - event2 = await self.UUIDFkRelatedModel.create(name="Event2", model=tour) - self.assertEqual(await tour.children.order_by("-name"), [event2, event1]) - self.assertEqual(await tour.children.order_by("name"), [event1, event2]) - - async def test_event__limit(self): - tour = await self.UUIDPkModel.create() - event1 = await 
self.UUIDFkRelatedModel.create(name="Event1", model=tour) - event2 = await self.UUIDFkRelatedModel.create(name="Event2", model=tour) - await self.UUIDFkRelatedModel.create(name="Event3", model=tour) - self.assertEqual(await tour.children.limit(2).order_by("name"), [event1, event2]) - - async def test_event__offset(self): - tour = await self.UUIDPkModel.create() - await self.UUIDFkRelatedModel.create(name="Event1", model=tour) - event2 = await self.UUIDFkRelatedModel.create(name="Event2", model=tour) - event3 = await self.UUIDFkRelatedModel.create(name="Event3", model=tour) - self.assertEqual(await tour.children.offset(1).order_by("name"), [event2, event3]) - - async def test_assign_by_id(self): - tour = await self.UUIDPkModel.create() - event = await self.UUIDFkRelatedNullModel.create(model=None) - event.model_id = tour.id - await event.save() - event0 = await self.UUIDFkRelatedNullModel.get(id=event.id) - self.assertEqual(event0.model_id, tour.id) - await event0.fetch_related("model") - self.assertEqual(event0.model, tour) - - async def test_assign_by_name(self): - tour = await self.UUIDPkModel.create() - event = await self.UUIDFkRelatedNullModel.create(model=None) - event.model = tour - await event.save() - event0 = await self.UUIDFkRelatedNullModel.get(id=event.id) - self.assertEqual(event0.model_id, tour.id) - await event0.fetch_related("model") - self.assertEqual(event0.model, tour) - - async def test_assign_none_by_id(self): - tour = await self.UUIDPkModel.create() - event = await self.UUIDFkRelatedNullModel.create(model=tour) - event.model_id = None + +@pytest.mark.asyncio +async def test_unfetched_contains(db, uuid_models): + UUIDPkModel, UUIDFkRelatedModel, UUIDFkRelatedNullModel = uuid_models + tour = await UUIDPkModel.create() + with pytest.raises( + NoValuesFetched, + match="No values were fetched for this relation, first use .fetch_related()", + ): + "a" in tour.children # pylint: disable=W0104 + + +@pytest.mark.asyncio +async def 
test_unfetched_iter(db, uuid_models): + UUIDPkModel, UUIDFkRelatedModel, UUIDFkRelatedNullModel = uuid_models + tour = await UUIDPkModel.create() + with pytest.raises( + NoValuesFetched, + match="No values were fetched for this relation, first use .fetch_related()", + ): + for _ in tour.children: + pass + + +@pytest.mark.asyncio +async def test_unfetched_len(db, uuid_models): + UUIDPkModel, UUIDFkRelatedModel, UUIDFkRelatedNullModel = uuid_models + tour = await UUIDPkModel.create() + with pytest.raises( + NoValuesFetched, + match="No values were fetched for this relation, first use .fetch_related()", + ): + len(tour.children) + + +@pytest.mark.asyncio +async def test_unfetched_bool(db, uuid_models): + UUIDPkModel, UUIDFkRelatedModel, UUIDFkRelatedNullModel = uuid_models + tour = await UUIDPkModel.create() + with pytest.raises( + NoValuesFetched, + match="No values were fetched for this relation, first use .fetch_related()", + ): + bool(tour.children) + + +@pytest.mark.asyncio +async def test_unfetched_getitem(db, uuid_models): + UUIDPkModel, UUIDFkRelatedModel, UUIDFkRelatedNullModel = uuid_models + tour = await UUIDPkModel.create() + with pytest.raises( + NoValuesFetched, + match="No values were fetched for this relation, first use .fetch_related()", + ): + tour.children[0] # pylint: disable=W0104 + + +@pytest.mark.asyncio +async def test_instantiated_create(db, uuid_models): + UUIDPkModel, UUIDFkRelatedModel, UUIDFkRelatedNullModel = uuid_models + tour = await UUIDPkModel.create() + await UUIDFkRelatedModel.create(model=tour) + + +@pytest.mark.asyncio +async def test_instantiated_iterate(db, uuid_models): + UUIDPkModel, UUIDFkRelatedModel, UUIDFkRelatedNullModel = uuid_models + tour = await UUIDPkModel.create() + async for _ in tour.children: + pass + + +@pytest.mark.asyncio +async def test_instantiated_await(db, uuid_models): + UUIDPkModel, UUIDFkRelatedModel, UUIDFkRelatedNullModel = uuid_models + tour = await UUIDPkModel.create() + await tour.children + + 
+@pytest.mark.asyncio +async def test_minimal__fetched_contains(db, uuid_models): + UUIDPkModel, UUIDFkRelatedModel, UUIDFkRelatedNullModel = uuid_models + tour = await UUIDPkModel.create() + rel = await UUIDFkRelatedModel.create(model=tour) + await tour.fetch_related("children") + assert rel in tour.children + + +@pytest.mark.asyncio +async def test_minimal__fetched_iter(db, uuid_models): + UUIDPkModel, UUIDFkRelatedModel, UUIDFkRelatedNullModel = uuid_models + tour = await UUIDPkModel.create() + rel = await UUIDFkRelatedModel.create(model=tour) + await tour.fetch_related("children") + assert list(tour.children) == [rel] + + +@pytest.mark.asyncio +async def test_minimal__fetched_len(db, uuid_models): + UUIDPkModel, UUIDFkRelatedModel, UUIDFkRelatedNullModel = uuid_models + tour = await UUIDPkModel.create() + await UUIDFkRelatedModel.create(model=tour) + await tour.fetch_related("children") + assert len(tour.children) == 1 + + +@pytest.mark.asyncio +async def test_minimal__fetched_bool(db, uuid_models): + UUIDPkModel, UUIDFkRelatedModel, UUIDFkRelatedNullModel = uuid_models + tour = await UUIDPkModel.create() + await tour.fetch_related("children") + assert not bool(tour.children) + await UUIDFkRelatedModel.create(model=tour) + await tour.fetch_related("children") + assert bool(tour.children) + + +@pytest.mark.asyncio +async def test_minimal__fetched_getitem(db, uuid_models): + UUIDPkModel, UUIDFkRelatedModel, UUIDFkRelatedNullModel = uuid_models + tour = await UUIDPkModel.create() + rel = await UUIDFkRelatedModel.create(model=tour) + await tour.fetch_related("children") + assert tour.children[0] == rel + + with pytest.raises(IndexError): + tour.children[1] # pylint: disable=W0104 + + +@pytest.mark.asyncio +async def test_event__filter(db, uuid_models): + UUIDPkModel, UUIDFkRelatedModel, UUIDFkRelatedNullModel = uuid_models + tour = await UUIDPkModel.create() + event1 = await UUIDFkRelatedModel.create(name="Event1", model=tour) + event2 = await 
UUIDFkRelatedModel.create(name="Event2", model=tour) + assert await tour.children.filter(name="Event1") == [event1] + assert await tour.children.filter(name="Event2") == [event2] + assert await tour.children.filter(name="Event3") == [] + + +@pytest.mark.asyncio +async def test_event__all(db, uuid_models): + UUIDPkModel, UUIDFkRelatedModel, UUIDFkRelatedNullModel = uuid_models + tour = await UUIDPkModel.create() + event1 = await UUIDFkRelatedModel.create(name="Event1", model=tour) + event2 = await UUIDFkRelatedModel.create(name="Event2", model=tour) + assert set(await tour.children.all()) == {event1, event2} + + +@pytest.mark.asyncio +async def test_event__order_by(db, uuid_models): + UUIDPkModel, UUIDFkRelatedModel, UUIDFkRelatedNullModel = uuid_models + tour = await UUIDPkModel.create() + event1 = await UUIDFkRelatedModel.create(name="Event1", model=tour) + event2 = await UUIDFkRelatedModel.create(name="Event2", model=tour) + assert await tour.children.order_by("-name") == [event2, event1] + assert await tour.children.order_by("name") == [event1, event2] + + +@pytest.mark.asyncio +async def test_event__limit(db, uuid_models): + UUIDPkModel, UUIDFkRelatedModel, UUIDFkRelatedNullModel = uuid_models + tour = await UUIDPkModel.create() + event1 = await UUIDFkRelatedModel.create(name="Event1", model=tour) + event2 = await UUIDFkRelatedModel.create(name="Event2", model=tour) + await UUIDFkRelatedModel.create(name="Event3", model=tour) + assert await tour.children.limit(2).order_by("name") == [event1, event2] + + +@pytest.mark.asyncio +async def test_event__offset(db, uuid_models): + UUIDPkModel, UUIDFkRelatedModel, UUIDFkRelatedNullModel = uuid_models + tour = await UUIDPkModel.create() + await UUIDFkRelatedModel.create(name="Event1", model=tour) + event2 = await UUIDFkRelatedModel.create(name="Event2", model=tour) + event3 = await UUIDFkRelatedModel.create(name="Event3", model=tour) + assert await tour.children.offset(1).order_by("name") == [event2, event3] + + 
+@pytest.mark.asyncio +async def test_assign_by_id(db, uuid_models): + UUIDPkModel, UUIDFkRelatedModel, UUIDFkRelatedNullModel = uuid_models + tour = await UUIDPkModel.create() + event = await UUIDFkRelatedNullModel.create(model=None) + event.model_id = tour.id + await event.save() + event0 = await UUIDFkRelatedNullModel.get(id=event.id) + assert event0.model_id == tour.id + await event0.fetch_related("model") + assert event0.model == tour + + +@pytest.mark.asyncio +async def test_assign_by_name(db, uuid_models): + UUIDPkModel, UUIDFkRelatedModel, UUIDFkRelatedNullModel = uuid_models + tour = await UUIDPkModel.create() + event = await UUIDFkRelatedNullModel.create(model=None) + event.model = tour + await event.save() + event0 = await UUIDFkRelatedNullModel.get(id=event.id) + assert event0.model_id == tour.id + await event0.fetch_related("model") + assert event0.model == tour + + +@pytest.mark.asyncio +async def test_assign_none_by_id(db, uuid_models): + UUIDPkModel, UUIDFkRelatedModel, UUIDFkRelatedNullModel = uuid_models + tour = await UUIDPkModel.create() + event = await UUIDFkRelatedNullModel.create(model=tour) + event.model_id = None + await event.save() + event0 = await UUIDFkRelatedNullModel.get(id=event.id) + assert event0.model_id is None + await event0.fetch_related("model") + assert event0.model is None + + +@pytest.mark.asyncio +async def test_assign_none_by_name(db, uuid_models): + UUIDPkModel, UUIDFkRelatedModel, UUIDFkRelatedNullModel = uuid_models + tour = await UUIDPkModel.create() + event = await UUIDFkRelatedNullModel.create(model=tour) + event.model = None + await event.save() + event0 = await UUIDFkRelatedNullModel.get(id=event.id) + assert event0.model_id is None + await event0.fetch_related("model") + assert event0.model is None + + +@pytest.mark.asyncio +async def test_assign_none_by_id_fails(db, uuid_models): + UUIDPkModel, UUIDFkRelatedModel, UUIDFkRelatedNullModel = uuid_models + tour = await UUIDPkModel.create() + event = await 
UUIDFkRelatedModel.create(model=tour) + event.model_id = None + with pytest.raises(IntegrityError): await event.save() - event0 = await self.UUIDFkRelatedNullModel.get(id=event.id) - self.assertEqual(event0.model_id, None) - await event0.fetch_related("model") - self.assertEqual(event0.model, None) - - async def test_assign_none_by_name(self): - tour = await self.UUIDPkModel.create() - event = await self.UUIDFkRelatedNullModel.create(model=tour) - event.model = None + + +@pytest.mark.asyncio +async def test_assign_none_by_name_fails(db, uuid_models): + UUIDPkModel, UUIDFkRelatedModel, UUIDFkRelatedNullModel = uuid_models + tour = await UUIDPkModel.create() + event = await UUIDFkRelatedModel.create(model=tour) + event.model = None + with pytest.raises(IntegrityError): await event.save() - event0 = await self.UUIDFkRelatedNullModel.get(id=event.id) - self.assertEqual(event0.model_id, None) - await event0.fetch_related("model") - self.assertEqual(event0.model, None) - - async def test_assign_none_by_id_fails(self): - tour = await self.UUIDPkModel.create() - event = await self.UUIDFkRelatedModel.create(model=tour) - event.model_id = None - with self.assertRaises(IntegrityError): - await event.save() - - async def test_assign_none_by_name_fails(self): - tour = await self.UUIDPkModel.create() - event = await self.UUIDFkRelatedModel.create(model=tour) - event.model = None - with self.assertRaises(IntegrityError): - await event.save() - - async def test_delete_by_name(self): - tour = await self.UUIDPkModel.create() - event = await self.UUIDFkRelatedModel.create(model=tour) - del event.model - with self.assertRaises(IntegrityError): - await event.save() - - -class TestForeignKeyUUIDSourcedField(TestForeignKeyUUIDField): - """ - Here we test the identical Python-like models, but with all customized DB names. - This helps test that we don't confuse the two concepts. 
- """ - UUIDPkModel = testmodels.UUIDPkSourceModel # type: ignore - UUIDFkRelatedModel = testmodels.UUIDFkRelatedSourceModel # type: ignore - UUIDFkRelatedNullModel = testmodels.UUIDFkRelatedNullSourceModel # type: ignore +@pytest.mark.asyncio +async def test_delete_by_name(db, uuid_models): + UUIDPkModel, UUIDFkRelatedModel, UUIDFkRelatedNullModel = uuid_models + tour = await UUIDPkModel.create() + event = await UUIDFkRelatedModel.create(model=tour) + del event.model + with pytest.raises(IntegrityError): + await event.save() diff --git a/tests/fields/test_fk_with_unique.py b/tests/fields/test_fk_with_unique.py index 6dc71c852..235679a13 100644 --- a/tests/fields/test_fk_with_unique.py +++ b/tests/fields/test_fk_with_unique.py @@ -1,225 +1,282 @@ +import pytest + from tests import testmodels -from tortoise.contrib import test from tortoise.exceptions import IntegrityError, NoValuesFetched, OperationalError from tortoise.queryset import QuerySet -class TestForeignKeyFieldWithUnique(test.TestCase): - async def test_student__empty(self): - with self.assertRaises(IntegrityError): - await testmodels.Student.create() - - async def test_student__create_by_id(self): - school = await testmodels.School.create(id=1024, name="School1") - student = await testmodels.Student.create(name="Sang-Heon Jeon", school_id=school.id) - self.assertEqual(student.school_id, school.id) - self.assertEqual((await school.students.all())[0], student) - - async def test_student__create_by_name(self): - school = await testmodels.School.create(id=1024, name="School1") - student = await testmodels.Student.create(name="Sang-Heon Jeon", school=school) - await student.fetch_related("school") - self.assertEqual(student.school, school) - self.assertEqual((await school.students.all())[0], student) - - async def test_student__by_name__created_prefetched(self): - school = await testmodels.School.create(id=1024, name="School1") - student = await testmodels.Student.create(name="Sang-Heon Jeon", school=school) 
- self.assertEqual(student.school, school) - self.assertEqual((await school.students.all())[0], student) - - async def test_student__by_name__unfetched(self): - school = await testmodels.School.create(id=1024, name="School1") - student = await testmodels.Student.create(name="Sang-Heon Jeon", school=school) - student = await testmodels.Student.get(id=student.id) - self.assertIsInstance(student.school, QuerySet) - - async def test_student__by_name__re_awaited(self): - school = await testmodels.School.create(id=1024, name="School1") - student = await testmodels.Student.create(name="Sang-Heon Jeon", school=school) - await student.fetch_related("school") - self.assertEqual(student.school, school) - self.assertEqual(await student.school, school) - - async def test_student__by_name__awaited(self): - school = await testmodels.School.create(id=1024, name="School1") - student = await testmodels.Student.create(name="Sang-Heon Jeon", school=school) - student = await testmodels.Student.get(id=student.id) - self.assertEqual(await student.school, school) - self.assertEqual((await school.students.all())[0], student) - - async def test_update_by_name(self): - school = await testmodels.School.create(id=1024, name="School1") - school2 = await testmodels.School.create(id=2048, name="School2") - student0 = await testmodels.Student.create(name="Sang-Heon Jeon", school=school) - - await testmodels.Student.filter(id=student0.id).update(school=school2) - student = await testmodels.Student.get(id=student0.id) - - await student.fetch_related("school") - self.assertEqual(student.school, school2) - self.assertEqual(await school.students.all(), []) - self.assertEqual((await school2.students.all())[0], student) - - async def test_update_by_id(self): - school = await testmodels.School.create(id=1024, name="School1") - school2 = await testmodels.School.create(id=2048, name="School2") - student0 = await testmodels.Student.create(name="Sang-Heon Jeon", school_id=school.id) - - await 
testmodels.Student.filter(id=student0.id).update(school_id=school2.id) - student = await testmodels.Student.get(id=student0.id) - - self.assertEqual(student.school_id, school2.id) - self.assertEqual(await school.students.all(), []) - self.assertEqual((await school2.students.all())[0], student) - - async def test_delete_by_name(self): - school = await testmodels.School.create(id=1024, name="School1") - student = await testmodels.Student.create(name="Sang-Heon Jeon", school=school) - del student.school - with self.assertRaises(IntegrityError): - await student.save() - - async def test_student__uninstantiated_create(self): - school = await testmodels.School(id=1024, name="School1") - with self.assertRaisesRegex(OperationalError, "You should first call .save()"): - await testmodels.Student.create(name="Sang-Heon Jeon", school=school) - - async def test_student__uninstantiated_iterate(self): - school = await testmodels.School(id=1024, name="School1") - with self.assertRaisesRegex( - OperationalError, "This objects hasn't been instanced, call .save()" - ): - async for _ in school.students: - pass - - async def test_student__uninstantiated_await(self): - school = await testmodels.School(id=1024, name="School1") - with self.assertRaisesRegex( - OperationalError, "This objects hasn't been instanced, call .save()" - ): - await school.students - - async def test_student__unfetched_contains(self): - school = await testmodels.School.create(id=1024, name="School1") - with self.assertRaisesRegex( - NoValuesFetched, - "No values were fetched for this relation, first use .fetch_related()", - ): - "a" in school.students # pylint: disable=W0104 - - async def test_stduent__unfetched_iter(self): - school = await testmodels.School.create(id=1024, name="School1") - with self.assertRaisesRegex( - NoValuesFetched, - "No values were fetched for this relation, first use .fetch_related()", - ): - for _ in school.students: - pass - - async def test_student__unfetched_len(self): - school = 
await testmodels.School.create(id=1024, name="School1") - with self.assertRaisesRegex( - NoValuesFetched, - "No values were fetched for this relation, first use .fetch_related()", - ): - len(school.students) - - async def test_student__unfetched_bool(self): - school = await testmodels.School.create(id=1024, name="School1") - with self.assertRaisesRegex( - NoValuesFetched, - "No values were fetched for this relation, first use .fetch_related()", - ): - bool(school.students) - - async def test_student__unfetched_getitem(self): - school = await testmodels.School.create(id=1024, name="School1") - with self.assertRaisesRegex( - NoValuesFetched, - "No values were fetched for this relation, first use .fetch_related()", - ): - school.students[0] # pylint: disable=W0104 - - async def test_student__instantiated_create(self): - school = await testmodels.School.create(id=1024, name="School1") +@pytest.mark.asyncio +async def test_student__empty(db): + with pytest.raises(IntegrityError): + await testmodels.Student.create() + + +@pytest.mark.asyncio +async def test_student__create_by_id(db): + school = await testmodels.School.create(id=1024, name="School1") + student = await testmodels.Student.create(name="Sang-Heon Jeon", school_id=school.id) + assert student.school_id == school.id + assert (await school.students.all())[0] == student + + +@pytest.mark.asyncio +async def test_student__create_by_name(db): + school = await testmodels.School.create(id=1024, name="School1") + student = await testmodels.Student.create(name="Sang-Heon Jeon", school=school) + await student.fetch_related("school") + assert student.school == school + assert (await school.students.all())[0] == student + + +@pytest.mark.asyncio +async def test_student__by_name__created_prefetched(db): + school = await testmodels.School.create(id=1024, name="School1") + student = await testmodels.Student.create(name="Sang-Heon Jeon", school=school) + assert student.school == school + assert (await school.students.all())[0] 
== student + + +@pytest.mark.asyncio +async def test_student__by_name__unfetched(db): + school = await testmodels.School.create(id=1024, name="School1") + student = await testmodels.Student.create(name="Sang-Heon Jeon", school=school) + student = await testmodels.Student.get(id=student.id) + assert isinstance(student.school, QuerySet) + + +@pytest.mark.asyncio +async def test_student__by_name__re_awaited(db): + school = await testmodels.School.create(id=1024, name="School1") + student = await testmodels.Student.create(name="Sang-Heon Jeon", school=school) + await student.fetch_related("school") + assert student.school == school + assert await student.school == school + + +@pytest.mark.asyncio +async def test_student__by_name__awaited(db): + school = await testmodels.School.create(id=1024, name="School1") + student = await testmodels.Student.create(name="Sang-Heon Jeon", school=school) + student = await testmodels.Student.get(id=student.id) + assert await student.school == school + assert (await school.students.all())[0] == student + + +@pytest.mark.asyncio +async def test_update_by_name(db): + school = await testmodels.School.create(id=1024, name="School1") + school2 = await testmodels.School.create(id=2048, name="School2") + student0 = await testmodels.Student.create(name="Sang-Heon Jeon", school=school) + + await testmodels.Student.filter(id=student0.id).update(school=school2) + student = await testmodels.Student.get(id=student0.id) + + await student.fetch_related("school") + assert student.school == school2 + assert await school.students.all() == [] + assert (await school2.students.all())[0] == student + + +@pytest.mark.asyncio +async def test_update_by_id(db): + school = await testmodels.School.create(id=1024, name="School1") + school2 = await testmodels.School.create(id=2048, name="School2") + student0 = await testmodels.Student.create(name="Sang-Heon Jeon", school_id=school.id) + + await testmodels.Student.filter(id=student0.id).update(school_id=school2.id) + 
student = await testmodels.Student.get(id=student0.id) + + assert student.school_id == school2.id + assert await school.students.all() == [] + assert (await school2.students.all())[0] == student + + +@pytest.mark.asyncio +async def test_delete_by_name(db): + school = await testmodels.School.create(id=1024, name="School1") + student = await testmodels.Student.create(name="Sang-Heon Jeon", school=school) + del student.school + with pytest.raises(IntegrityError): + await student.save() + + +@pytest.mark.asyncio +async def test_student__uninstantiated_create(db): + school = testmodels.School(id=1024, name="School1") + with pytest.raises(OperationalError, match="You should first call .save()"): await testmodels.Student.create(name="Sang-Heon Jeon", school=school) - async def test_student__instantiated_iterate(self): - school = await testmodels.School.create(id=1024, name="School1") + +@pytest.mark.asyncio +async def test_student__uninstantiated_iterate(db): + school = testmodels.School(id=1024, name="School1") + with pytest.raises(OperationalError, match="This objects hasn't been instanced, call .save()"): async for _ in school.students: pass - async def test_student__instantiated_await(self): - school = await testmodels.School.create(id=1024, name="School1") + +@pytest.mark.asyncio +async def test_student__uninstantiated_await(db): + school = testmodels.School(id=1024, name="School1") + with pytest.raises(OperationalError, match="This objects hasn't been instanced, call .save()"): await school.students - async def test_student__fetched_contains(self): - school = await testmodels.School.create(id=1024, name="School1") - student = await testmodels.Student.create(name="Sang-Heon Jeon", school=school) - await school.fetch_related("students") - self.assertTrue(student in school.students) - async def test_student__fetched_iter(self): - school = await testmodels.School.create(id=1024, name="School1") - student = await 
testmodels.Student.create(name="Sang-Heon Jeon", school=school) - await school.fetch_related("students") - self.assertEqual(list(school.students), [student]) +@pytest.mark.asyncio +async def test_student__unfetched_contains(db): + school = await testmodels.School.create(id=1024, name="School1") + with pytest.raises( + NoValuesFetched, + match="No values were fetched for this relation, first use .fetch_related()", + ): + "a" in school.students # pylint: disable=W0104 + + +@pytest.mark.asyncio +async def test_student__unfetched_iter(db): + school = await testmodels.School.create(id=1024, name="School1") + with pytest.raises( + NoValuesFetched, + match="No values were fetched for this relation, first use .fetch_related()", + ): + for _ in school.students: + pass - async def test_student__fetched_len(self): - school = await testmodels.School.create(id=1024, name="School1") - await testmodels.Student.create(name="Sang-Heon Jeon", school=school) - await school.fetch_related("students") - self.assertEqual(len(school.students), 1) - async def test_student__fetched_bool(self): - school = await testmodels.School.create(id=1024, name="School1") - await school.fetch_related("students") - self.assertFalse(bool(school.students)) - await testmodels.Student.create(name="Sang-Heon Jeon", school=school) - await school.fetch_related("students") - self.assertTrue(bool(school.students)) - - async def test_student__fetched_getitem(self): - school = await testmodels.School.create(id=1024, name="School1") - student = await testmodels.Student.create(name="Sang-Heon Jeon", school=school) - await school.fetch_related("students") - self.assertEqual(school.students[0], student) - - with self.assertRaises(IndexError): - school.students[1] # pylint: disable=W0104 - - async def test_student__filter(self): - school = await testmodels.School.create(id=1024, name="School1") - student1 = await testmodels.Student.create(name="Sang-Heon Jeon1", school=school) - student2 = await 
testmodels.Student.create(name="Sang-Heon Jeon2", school=school) - self.assertEqual(await school.students.filter(name="Sang-Heon Jeon1"), [student1]) - self.assertEqual(await school.students.filter(name="Sang-Heon Jeon2"), [student2]) - self.assertEqual(await school.students.filter(name="Sang-Heon Jeon3"), []) - - async def test_student__all(self): - school = await testmodels.School.create(id=1024, name="School1") - student1 = await testmodels.Student.create(name="Sang-Heon Jeon1", school=school) - student2 = await testmodels.Student.create(name="Sang-Heon Jeon2", school=school) - self.assertEqual(set(await school.students.all()), {student1, student2}) - - async def test_student_order_by(self): - school = await testmodels.School.create(id=1024, name="School1") - student1 = await testmodels.Student.create(name="Sang-Heon Jeon1", school=school) - student2 = await testmodels.Student.create(name="Sang-Heon Jeon2", school=school) - self.assertEqual(await school.students.order_by("-name"), [student2, student1]) - self.assertEqual(await school.students.order_by("name"), [student1, student2]) - - async def test_student__limit(self): - school = await testmodels.School.create(id=1024, name="School1") - student1 = await testmodels.Student.create(name="Sang-Heon Jeon1", school=school) - student2 = await testmodels.Student.create(name="Sang-Heon Jeon2", school=school) - await testmodels.Student.create(name="Sang-Heon Jeon3", school=school) - self.assertEqual(await school.students.limit(2).order_by("name"), [student1, student2]) - - async def test_student_offset(self): - school = await testmodels.School.create(id=1024, name="School1") - await testmodels.Student.create(name="Sang-Heon Jeon1", school=school) - student2 = await testmodels.Student.create(name="Sang-Heon Jeon2", school=school) - student3 = await testmodels.Student.create(name="Sang-Heon Jeon3", school=school) - self.assertEqual(await school.students.offset(1).order_by("name"), [student2, student3]) 
+@pytest.mark.asyncio +async def test_student__unfetched_len(db): + school = await testmodels.School.create(id=1024, name="School1") + with pytest.raises( + NoValuesFetched, + match="No values were fetched for this relation, first use .fetch_related()", + ): + len(school.students) + + +@pytest.mark.asyncio +async def test_student__unfetched_bool(db): + school = await testmodels.School.create(id=1024, name="School1") + with pytest.raises( + NoValuesFetched, + match="No values were fetched for this relation, first use .fetch_related()", + ): + bool(school.students) + + +@pytest.mark.asyncio +async def test_student__unfetched_getitem(db): + school = await testmodels.School.create(id=1024, name="School1") + with pytest.raises( + NoValuesFetched, + match="No values were fetched for this relation, first use .fetch_related()", + ): + school.students[0] # pylint: disable=W0104 + + +@pytest.mark.asyncio +async def test_student__instantiated_create(db): + school = await testmodels.School.create(id=1024, name="School1") + await testmodels.Student.create(name="Sang-Heon Jeon", school=school) + + +@pytest.mark.asyncio +async def test_student__instantiated_iterate(db): + school = await testmodels.School.create(id=1024, name="School1") + async for _ in school.students: + pass + + +@pytest.mark.asyncio +async def test_student__instantiated_await(db): + school = await testmodels.School.create(id=1024, name="School1") + await school.students + + +@pytest.mark.asyncio +async def test_student__fetched_contains(db): + school = await testmodels.School.create(id=1024, name="School1") + student = await testmodels.Student.create(name="Sang-Heon Jeon", school=school) + await school.fetch_related("students") + assert student in school.students + + +@pytest.mark.asyncio +async def test_student__fetched_iter(db): + school = await testmodels.School.create(id=1024, name="School1") + student = await testmodels.Student.create(name="Sang-Heon Jeon", school=school) + await 
school.fetch_related("students") + assert list(school.students) == [student] + + +@pytest.mark.asyncio +async def test_student__fetched_len(db): + school = await testmodels.School.create(id=1024, name="School1") + await testmodels.Student.create(name="Sang-Heon Jeon", school=school) + await school.fetch_related("students") + assert len(school.students) == 1 + + +@pytest.mark.asyncio +async def test_student__fetched_bool(db): + school = await testmodels.School.create(id=1024, name="School1") + await school.fetch_related("students") + assert not bool(school.students) + await testmodels.Student.create(name="Sang-Heon Jeon", school=school) + await school.fetch_related("students") + assert bool(school.students) + + +@pytest.mark.asyncio +async def test_student__fetched_getitem(db): + school = await testmodels.School.create(id=1024, name="School1") + student = await testmodels.Student.create(name="Sang-Heon Jeon", school=school) + await school.fetch_related("students") + assert school.students[0] == student + + with pytest.raises(IndexError): + school.students[1] # pylint: disable=W0104 + + +@pytest.mark.asyncio +async def test_student__filter(db): + school = await testmodels.School.create(id=1024, name="School1") + student1 = await testmodels.Student.create(name="Sang-Heon Jeon1", school=school) + student2 = await testmodels.Student.create(name="Sang-Heon Jeon2", school=school) + assert await school.students.filter(name="Sang-Heon Jeon1") == [student1] + assert await school.students.filter(name="Sang-Heon Jeon2") == [student2] + assert await school.students.filter(name="Sang-Heon Jeon3") == [] + + +@pytest.mark.asyncio +async def test_student__all(db): + school = await testmodels.School.create(id=1024, name="School1") + student1 = await testmodels.Student.create(name="Sang-Heon Jeon1", school=school) + student2 = await testmodels.Student.create(name="Sang-Heon Jeon2", school=school) + assert set(await school.students.all()) == {student1, student2} + + 
+@pytest.mark.asyncio +async def test_student_order_by(db): + school = await testmodels.School.create(id=1024, name="School1") + student1 = await testmodels.Student.create(name="Sang-Heon Jeon1", school=school) + student2 = await testmodels.Student.create(name="Sang-Heon Jeon2", school=school) + assert await school.students.order_by("-name") == [student2, student1] + assert await school.students.order_by("name") == [student1, student2] + + +@pytest.mark.asyncio +async def test_student__limit(db): + school = await testmodels.School.create(id=1024, name="School1") + student1 = await testmodels.Student.create(name="Sang-Heon Jeon1", school=school) + student2 = await testmodels.Student.create(name="Sang-Heon Jeon2", school=school) + await testmodels.Student.create(name="Sang-Heon Jeon3", school=school) + assert await school.students.limit(2).order_by("name") == [student1, student2] + + +@pytest.mark.asyncio +async def test_student_offset(db): + school = await testmodels.School.create(id=1024, name="School1") + await testmodels.Student.create(name="Sang-Heon Jeon1", school=school) + student2 = await testmodels.Student.create(name="Sang-Heon Jeon2", school=school) + student3 = await testmodels.Student.create(name="Sang-Heon Jeon3", school=school) + assert await school.students.offset(1).order_by("name") == [student2, student3] diff --git a/tests/fields/test_float.py b/tests/fields/test_float.py index ed6367d92..e6c521395 100644 --- a/tests/fields/test_float.py +++ b/tests/fields/test_float.py @@ -1,56 +1,71 @@ from decimal import Decimal +import pytest + from tests import testmodels -from tortoise.contrib import test from tortoise.exceptions import IntegrityError from tortoise.expressions import F -class TestFloatFields(test.TestCase): - async def test_empty(self): - with self.assertRaises(IntegrityError): - await testmodels.FloatFields.create() - - async def test_create(self): - obj0 = await testmodels.FloatFields.create(floatnum=1.23) - obj = await 
testmodels.FloatFields.get(id=obj0.id) - self.assertEqual(obj.floatnum, 1.23) - self.assertNotEqual(Decimal(obj.floatnum), Decimal("1.23")) - self.assertEqual(obj.floatnum_null, None) - await obj.save() - obj2 = await testmodels.FloatFields.get(id=obj.id) - self.assertEqual(obj, obj2) - - async def test_update(self): - obj0 = await testmodels.FloatFields.create(floatnum=1.23) - await testmodels.FloatFields.filter(id=obj0.id).update(floatnum=2.34) - obj = await testmodels.FloatFields.get(id=obj0.id) - self.assertEqual(obj.floatnum, 2.34) - self.assertNotEqual(Decimal(obj.floatnum), Decimal("2.34")) - self.assertEqual(obj.floatnum_null, None) - - async def test_cast_int(self): - obj0 = await testmodels.FloatFields.create(floatnum=123) - obj = await testmodels.FloatFields.get(id=obj0.id) - self.assertEqual(obj.floatnum, 123) - - async def test_cast_decimal(self): - obj0 = await testmodels.FloatFields.create(floatnum=Decimal("1.23")) - obj = await testmodels.FloatFields.get(id=obj0.id) - self.assertEqual(obj.floatnum, 1.23) - - async def test_values(self): - obj0 = await testmodels.FloatFields.create(floatnum=1.23) - values = await testmodels.FloatFields.filter(id=obj0.id).values("floatnum") - self.assertEqual(values[0]["floatnum"], 1.23) - - async def test_values_list(self): - obj0 = await testmodels.FloatFields.create(floatnum=1.23) - values = await testmodels.FloatFields.filter(id=obj0.id).values_list("floatnum") - self.assertEqual(list(values[0]), [1.23]) - - async def test_f_expression(self): - obj0 = await testmodels.FloatFields.create(floatnum=1.23) - await obj0.filter(id=obj0.id).update(floatnum=F("floatnum") + 0.01) - obj1 = await testmodels.FloatFields.get(id=obj0.id) - self.assertEqual(obj1.floatnum, 1.24) +@pytest.mark.asyncio +async def test_empty(db): + with pytest.raises(IntegrityError): + await testmodels.FloatFields.create() + + +@pytest.mark.asyncio +async def test_create(db): + obj0 = await testmodels.FloatFields.create(floatnum=1.23) + obj = await 
testmodels.FloatFields.get(id=obj0.id) + assert obj.floatnum == 1.23 + assert Decimal(obj.floatnum) != Decimal("1.23") + assert obj.floatnum_null is None + await obj.save() + obj2 = await testmodels.FloatFields.get(id=obj.id) + assert obj == obj2 + + +@pytest.mark.asyncio +async def test_update(db): + obj0 = await testmodels.FloatFields.create(floatnum=1.23) + await testmodels.FloatFields.filter(id=obj0.id).update(floatnum=2.34) + obj = await testmodels.FloatFields.get(id=obj0.id) + assert obj.floatnum == 2.34 + assert Decimal(obj.floatnum) != Decimal("2.34") + assert obj.floatnum_null is None + + +@pytest.mark.asyncio +async def test_cast_int(db): + obj0 = await testmodels.FloatFields.create(floatnum=123) + obj = await testmodels.FloatFields.get(id=obj0.id) + assert obj.floatnum == 123 + + +@pytest.mark.asyncio +async def test_cast_decimal(db): + obj0 = await testmodels.FloatFields.create(floatnum=Decimal("1.23")) + obj = await testmodels.FloatFields.get(id=obj0.id) + assert obj.floatnum == 1.23 + + +@pytest.mark.asyncio +async def test_values(db): + obj0 = await testmodels.FloatFields.create(floatnum=1.23) + values = await testmodels.FloatFields.filter(id=obj0.id).values("floatnum") + assert values[0]["floatnum"] == 1.23 + + +@pytest.mark.asyncio +async def test_values_list(db): + obj0 = await testmodels.FloatFields.create(floatnum=1.23) + values = await testmodels.FloatFields.filter(id=obj0.id).values_list("floatnum") + assert list(values[0]) == [1.23] + + +@pytest.mark.asyncio +async def test_f_expression(db): + obj0 = await testmodels.FloatFields.create(floatnum=1.23) + await obj0.filter(id=obj0.id).update(floatnum=F("floatnum") + 0.01) + obj1 = await testmodels.FloatFields.get(id=obj0.id) + assert obj1.floatnum == 1.24 diff --git a/tests/fields/test_int.py b/tests/fields/test_int.py index d5bd2b771..e948e61ea 100644 --- a/tests/fields/test_int.py +++ b/tests/fields/test_int.py @@ -1,147 +1,196 @@ +import pytest + from tests import testmodels -from 
tortoise.contrib import test from tortoise.exceptions import IntegrityError from tortoise.expressions import F +# ============================================================================ +# TestIntFields +# ============================================================================ + + +@pytest.mark.asyncio +async def test_int_fields_empty(db): + with pytest.raises(IntegrityError): + await testmodels.IntFields.create() + + +@pytest.mark.asyncio +async def test_int_fields_create(db): + obj0 = await testmodels.IntFields.create(intnum=2147483647) + obj = await testmodels.IntFields.get(id=obj0.id) + assert obj.intnum == 2147483647 + assert obj.intnum_null is None + + obj2 = await testmodels.IntFields.get(id=obj.id) + assert obj == obj2 + + await obj.delete() + obj = await testmodels.IntFields.filter(id=obj0.id).first() + assert obj is None + + +@pytest.mark.asyncio +async def test_int_fields_update(db): + obj0 = await testmodels.IntFields.create(intnum=2147483647) + await testmodels.IntFields.filter(id=obj0.id).update(intnum=2147483646) + obj = await testmodels.IntFields.get(id=obj0.id) + assert obj.intnum == 2147483646 + assert obj.intnum_null is None + + +@pytest.mark.asyncio +async def test_int_fields_min(db): + obj0 = await testmodels.IntFields.create(intnum=-2147483648) + obj = await testmodels.IntFields.get(id=obj0.id) + assert obj.intnum == -2147483648 + assert obj.intnum_null is None + + obj2 = await testmodels.IntFields.get(id=obj.id) + assert obj == obj2 + + +@pytest.mark.asyncio +async def test_int_fields_cast(db): + obj0 = await testmodels.IntFields.create(intnum="3") + obj = await testmodels.IntFields.get(id=obj0.id) + assert obj.intnum == 3 + + +@pytest.mark.asyncio +async def test_int_fields_values(db): + obj0 = await testmodels.IntFields.create(intnum=1) + values = await testmodels.IntFields.get(id=obj0.id).values("intnum") + assert values["intnum"] == 1 + + +@pytest.mark.asyncio +async def test_int_fields_values_list(db): + obj0 = await 
testmodels.IntFields.create(intnum=1) + values = await testmodels.IntFields.get(id=obj0.id).values_list("intnum", flat=True) + assert values == 1 + + +@pytest.mark.asyncio +async def test_int_fields_f_expression(db): + obj0 = await testmodels.IntFields.create(intnum=1) + await obj0.filter(id=obj0.id).update(intnum=F("intnum") + 1) + obj1 = await testmodels.IntFields.get(id=obj0.id) + assert obj1.intnum == 2 + + +# ============================================================================ +# TestSmallIntFields +# ============================================================================ + + +@pytest.mark.asyncio +async def test_small_int_fields_empty(db): + with pytest.raises(IntegrityError): + await testmodels.SmallIntFields.create() + + +@pytest.mark.asyncio +async def test_small_int_fields_create(db): + obj0 = await testmodels.SmallIntFields.create(smallintnum=32767) + obj = await testmodels.SmallIntFields.get(id=obj0.id) + assert obj.smallintnum == 32767 + assert obj.smallintnum_null is None + await obj.save() + obj2 = await testmodels.SmallIntFields.get(id=obj.id) + assert obj == obj2 + + +@pytest.mark.asyncio +async def test_small_int_fields_min(db): + obj0 = await testmodels.SmallIntFields.create(smallintnum=-32768) + obj = await testmodels.SmallIntFields.get(id=obj0.id) + assert obj.smallintnum == -32768 + assert obj.smallintnum_null is None + await obj.save() + obj2 = await testmodels.SmallIntFields.get(id=obj.id) + assert obj == obj2 + + +@pytest.mark.asyncio +async def test_small_int_fields_values(db): + obj0 = await testmodels.SmallIntFields.create(smallintnum=2) + values = await testmodels.SmallIntFields.get(id=obj0.id).values("smallintnum") + assert values["smallintnum"] == 2 + + +@pytest.mark.asyncio +async def test_small_int_fields_values_list(db): + obj0 = await testmodels.SmallIntFields.create(smallintnum=2) + values = await testmodels.SmallIntFields.get(id=obj0.id).values_list("smallintnum", flat=True) + assert values == 2 + + 
+@pytest.mark.asyncio +async def test_small_int_fields_f_expression(db): + obj0 = await testmodels.SmallIntFields.create(smallintnum=1) + await obj0.filter(id=obj0.id).update(smallintnum=F("smallintnum") + 1) + obj1 = await testmodels.SmallIntFields.get(id=obj0.id) + assert obj1.smallintnum == 2 + + +# ============================================================================ +# TestBigIntFields +# ============================================================================ + + +@pytest.mark.asyncio +async def test_big_int_fields_empty(db): + with pytest.raises(IntegrityError): + await testmodels.BigIntFields.create() + + +@pytest.mark.asyncio +async def test_big_int_fields_create(db): + obj0 = await testmodels.BigIntFields.create(intnum=9223372036854775807) + obj = await testmodels.BigIntFields.get(id=obj0.id) + assert obj.intnum == 9223372036854775807 + assert obj.intnum_null is None + await obj.save() + obj2 = await testmodels.BigIntFields.get(id=obj.id) + assert obj == obj2 + + +@pytest.mark.asyncio +async def test_big_int_fields_min(db): + obj0 = await testmodels.BigIntFields.create(intnum=-9223372036854775808) + obj = await testmodels.BigIntFields.get(id=obj0.id) + assert obj.intnum == -9223372036854775808 + assert obj.intnum_null is None + await obj.save() + obj2 = await testmodels.BigIntFields.get(id=obj.id) + assert obj == obj2 + + +@pytest.mark.asyncio +async def test_big_int_fields_cast(db): + obj0 = await testmodels.BigIntFields.create(intnum="3") + obj = await testmodels.BigIntFields.get(id=obj0.id) + assert obj.intnum == 3 + + +@pytest.mark.asyncio +async def test_big_int_fields_values(db): + obj0 = await testmodels.BigIntFields.create(intnum=1) + values = await testmodels.BigIntFields.get(id=obj0.id).values("intnum") + assert values["intnum"] == 1 + + +@pytest.mark.asyncio +async def test_big_int_fields_values_list(db): + obj0 = await testmodels.BigIntFields.create(intnum=1) + values = await 
testmodels.BigIntFields.get(id=obj0.id).values_list("intnum", flat=True) + assert values == 1 + -class TestIntFields(test.TestCase): - async def test_empty(self): - with self.assertRaises(IntegrityError): - await testmodels.IntFields.create() - - async def test_create(self): - obj0 = await testmodels.IntFields.create(intnum=2147483647) - obj = await testmodels.IntFields.get(id=obj0.id) - self.assertEqual(obj.intnum, 2147483647) - self.assertEqual(obj.intnum_null, None) - - obj2 = await testmodels.IntFields.get(id=obj.id) - self.assertEqual(obj, obj2) - - await obj.delete() - obj = await testmodels.IntFields.filter(id=obj0.id).first() - self.assertEqual(obj, None) - - async def test_update(self): - obj0 = await testmodels.IntFields.create(intnum=2147483647) - await testmodels.IntFields.filter(id=obj0.id).update(intnum=2147483646) - obj = await testmodels.IntFields.get(id=obj0.id) - self.assertEqual(obj.intnum, 2147483646) - self.assertEqual(obj.intnum_null, None) - - async def test_min(self): - obj0 = await testmodels.IntFields.create(intnum=-2147483648) - obj = await testmodels.IntFields.get(id=obj0.id) - self.assertEqual(obj.intnum, -2147483648) - self.assertEqual(obj.intnum_null, None) - - obj2 = await testmodels.IntFields.get(id=obj.id) - self.assertEqual(obj, obj2) - - async def test_cast(self): - obj0 = await testmodels.IntFields.create(intnum="3") - obj = await testmodels.IntFields.get(id=obj0.id) - self.assertEqual(obj.intnum, 3) - - async def test_values(self): - obj0 = await testmodels.IntFields.create(intnum=1) - values = await testmodels.IntFields.get(id=obj0.id).values("intnum") - self.assertEqual(values["intnum"], 1) - - async def test_values_list(self): - obj0 = await testmodels.IntFields.create(intnum=1) - values = await testmodels.IntFields.get(id=obj0.id).values_list("intnum", flat=True) - self.assertEqual(values, 1) - - async def test_f_expression(self): - obj0 = await testmodels.IntFields.create(intnum=1) - await 
obj0.filter(id=obj0.id).update(intnum=F("intnum") + 1) - obj1 = await testmodels.IntFields.get(id=obj0.id) - self.assertEqual(obj1.intnum, 2) - - -class TestSmallIntFields(test.TestCase): - async def test_empty(self): - with self.assertRaises(IntegrityError): - await testmodels.SmallIntFields.create() - - async def test_create(self): - obj0 = await testmodels.SmallIntFields.create(smallintnum=32767) - obj = await testmodels.SmallIntFields.get(id=obj0.id) - self.assertEqual(obj.smallintnum, 32767) - self.assertEqual(obj.smallintnum_null, None) - await obj.save() - obj2 = await testmodels.SmallIntFields.get(id=obj.id) - self.assertEqual(obj, obj2) - - async def test_min(self): - obj0 = await testmodels.SmallIntFields.create(smallintnum=-32768) - obj = await testmodels.SmallIntFields.get(id=obj0.id) - self.assertEqual(obj.smallintnum, -32768) - self.assertEqual(obj.smallintnum_null, None) - await obj.save() - obj2 = await testmodels.SmallIntFields.get(id=obj.id) - self.assertEqual(obj, obj2) - - async def test_values(self): - obj0 = await testmodels.SmallIntFields.create(smallintnum=2) - values = await testmodels.SmallIntFields.get(id=obj0.id).values("smallintnum") - self.assertEqual(values["smallintnum"], 2) - - async def test_values_list(self): - obj0 = await testmodels.SmallIntFields.create(smallintnum=2) - values = await testmodels.SmallIntFields.get(id=obj0.id).values_list( - "smallintnum", flat=True - ) - self.assertEqual(values, 2) - - async def test_f_expression(self): - obj0 = await testmodels.SmallIntFields.create(smallintnum=1) - await obj0.filter(id=obj0.id).update(smallintnum=F("smallintnum") + 1) - obj1 = await testmodels.SmallIntFields.get(id=obj0.id) - self.assertEqual(obj1.smallintnum, 2) - - -class TestBigIntFields(test.TestCase): - async def test_empty(self): - with self.assertRaises(IntegrityError): - await testmodels.BigIntFields.create() - - async def test_create(self): - obj0 = await testmodels.BigIntFields.create(intnum=9223372036854775807) - 
obj = await testmodels.BigIntFields.get(id=obj0.id) - self.assertEqual(obj.intnum, 9223372036854775807) - self.assertEqual(obj.intnum_null, None) - await obj.save() - obj2 = await testmodels.BigIntFields.get(id=obj.id) - self.assertEqual(obj, obj2) - - async def test_min(self): - obj0 = await testmodels.BigIntFields.create(intnum=-9223372036854775808) - obj = await testmodels.BigIntFields.get(id=obj0.id) - self.assertEqual(obj.intnum, -9223372036854775808) - self.assertEqual(obj.intnum_null, None) - await obj.save() - obj2 = await testmodels.BigIntFields.get(id=obj.id) - self.assertEqual(obj, obj2) - - async def test_cast(self): - obj0 = await testmodels.BigIntFields.create(intnum="3") - obj = await testmodels.BigIntFields.get(id=obj0.id) - self.assertEqual(obj.intnum, 3) - - async def test_values(self): - obj0 = await testmodels.BigIntFields.create(intnum=1) - values = await testmodels.BigIntFields.get(id=obj0.id).values("intnum") - self.assertEqual(values["intnum"], 1) - - async def test_values_list(self): - obj0 = await testmodels.BigIntFields.create(intnum=1) - values = await testmodels.BigIntFields.get(id=obj0.id).values_list("intnum", flat=True) - self.assertEqual(values, 1) - - async def test_f_expression(self): - obj0 = await testmodels.BigIntFields.create(intnum=1) - await obj0.filter(id=obj0.id).update(intnum=F("intnum") + 1) - obj1 = await testmodels.BigIntFields.get(id=obj0.id) - self.assertEqual(obj1.intnum, 2) +@pytest.mark.asyncio +async def test_big_int_fields_f_expression(db): + obj0 = await testmodels.BigIntFields.create(intnum=1) + await obj0.filter(id=obj0.id).update(intnum=F("intnum") + 1) + obj1 = await testmodels.BigIntFields.get(id=obj0.id) + assert obj1.intnum == 2 diff --git a/tests/fields/test_json.py b/tests/fields/test_json.py index 916df0943..4b6d2a0a4 100644 --- a/tests/fields/test_json.py +++ b/tests/fields/test_json.py @@ -1,5 +1,7 @@ +import pytest + from tests import testmodels -from tortoise.contrib import test +from 
tortoise.contrib.test import requireCapability from tortoise.contrib.test.condition import In from tortoise.exceptions import ( ConfigurationError, @@ -10,276 +12,330 @@ from tortoise.fields import JSONField -class TestJSONFields(test.TestCase): - async def test_empty(self): - with self.assertRaises(IntegrityError): - await testmodels.JSONFields.create() +@pytest.mark.asyncio +async def test_empty(db): + """Test that creating without required JSON field raises IntegrityError.""" + with pytest.raises(IntegrityError): + await testmodels.JSONFields.create() - async def test_create(self): - obj0 = await testmodels.JSONFields.create(data={"some": ["text", 3]}) - obj = await testmodels.JSONFields.get(id=obj0.id) - self.assertEqual(obj.data, {"some": ["text", 3]}) - self.assertEqual(obj.data_null, None) - await obj.save() - obj2 = await testmodels.JSONFields.get(id=obj.id) - self.assertEqual(obj, obj2) - - async def test_error(self): - with self.assertRaises(FieldError): - await testmodels.JSONFields.create(data='{"some": ') - - obj = await testmodels.JSONFields.create(data='{"some": ["text", 3]}') - with self.assertRaises(FieldError): - await testmodels.JSONFields.filter(pk=obj.pk).update(data='{"some": ') - - with self.assertRaises(FieldError): - obj.data = "error json" - await obj.save() - - async def test_update(self): - obj0 = await testmodels.JSONFields.create(data={"some": ["text", 3]}) - await testmodels.JSONFields.filter(id=obj0.id).update(data={"other": ["text", 5]}) - obj = await testmodels.JSONFields.get(id=obj0.id) - self.assertEqual(obj.data, {"other": ["text", 5]}) - self.assertEqual(obj.data_null, None) - - async def test_dict_str(self): - obj0 = await testmodels.JSONFields.create(data={"some": ["text", 3]}) - - obj = await testmodels.JSONFields.get(id=obj0.id) - self.assertEqual(obj.data, {"some": ["text", 3]}) - - await testmodels.JSONFields.filter(id=obj0.id).update(data='{"other": ["text", 5]}') - obj = await testmodels.JSONFields.get(id=obj0.id) - 
self.assertEqual(obj.data, {"other": ["text", 5]}) - - async def test_list_str(self): - obj = await testmodels.JSONFields.create(data='["text", 3]') - obj0 = await testmodels.JSONFields.get(id=obj.id) - self.assertEqual(obj0.data, ["text", 3]) - - await testmodels.JSONFields.filter(id=obj.id).update(data='["text", 5]') - obj0 = await testmodels.JSONFields.get(id=obj.id) - self.assertEqual(obj0.data, ["text", 5]) - - async def test_list(self): - obj0 = await testmodels.JSONFields.create(data=["text", 3]) - obj = await testmodels.JSONFields.get(id=obj0.id) - self.assertEqual(obj.data, ["text", 3]) - self.assertEqual(obj.data_null, None) - await obj.save() - obj2 = await testmodels.JSONFields.get(id=obj.id) - self.assertEqual(obj, obj2) - - @test.requireCapability(dialect=In("mysql", "postgres")) - async def test_list_contains(self): - await testmodels.JSONFields.create(data=["text", 3, {"msg": "msg2"}]) - obj = await testmodels.JSONFields.filter(data__contains=[{"msg": "msg2"}]).first() - self.assertEqual(obj.data, ["text", 3, {"msg": "msg2"}]) + +@pytest.mark.asyncio +async def test_create(db): + """Test JSON field creation and retrieval.""" + obj0 = await testmodels.JSONFields.create(data={"some": ["text", 3]}) + obj = await testmodels.JSONFields.get(id=obj0.id) + assert obj.data == {"some": ["text", 3]} + assert obj.data_null is None + await obj.save() + obj2 = await testmodels.JSONFields.get(id=obj.id) + assert obj == obj2 + + +@pytest.mark.asyncio +async def test_error(db): + """Test that invalid JSON raises FieldError.""" + with pytest.raises(FieldError): + await testmodels.JSONFields.create(data='{"some": ') + + obj = await testmodels.JSONFields.create(data='{"some": ["text", 3]}') + with pytest.raises(FieldError): + await testmodels.JSONFields.filter(pk=obj.pk).update(data='{"some": ') + + with pytest.raises(FieldError): + obj.data = "error json" await obj.save() - obj2 = await testmodels.JSONFields.get(id=obj.id) - self.assertEqual(obj, obj2) - - 
@test.requireCapability(dialect=In("mysql", "postgres")) - async def test_list_contained_by(self): - obj0 = await testmodels.JSONFields.create(data=["text"]) - obj1 = await testmodels.JSONFields.create(data=["tortoise", "msg"]) - obj2 = await testmodels.JSONFields.create(data=["tortoise"]) - obj3 = await testmodels.JSONFields.create(data=["new_message", "some_message"]) - objs = set( - await testmodels.JSONFields.filter(data__contained_by=["text", "tortoise", "msg"]) - ) - created_objs = {obj0, obj1, obj2} - self.assertSetEqual(created_objs, objs) - self.assertTrue(obj3 not in objs) - - @test.requireCapability(dialect=In("mysql", "postgres")) - async def test_filter(self): - obj0 = await testmodels.JSONFields.create( - data={ - "breed": "labrador", - "owner": { - "name": "Bob", - "last": None, - "other_pets": [ - { - "name": "Fishy", - } - ], - }, - } - ) - obj1 = await testmodels.JSONFields.create( - data={ - "breed": "husky", - "owner": { - "name": "Goldast", - "last": None, - "other_pets": [ - { - "name": None, - } - ], - }, - } - ) - obj = await testmodels.JSONFields.get(data__filter={"breed": "labrador"}) - obj2 = await testmodels.JSONFields.get(data__filter={"owner__name": "Goldast"}) - obj3 = await testmodels.JSONFields.get(data__filter={"owner__other_pets__0__name": "Fishy"}) - - self.assertEqual(obj0, obj) - self.assertEqual(obj1, obj2) - self.assertEqual(obj0, obj3) - - with self.assertRaises(DoesNotExist): - obj = await testmodels.JSONFields.get(data__filter={"breed": "NotFound"}) - with self.assertRaises(DoesNotExist): - await testmodels.JSONFields.get(data__filter={"owner__other_pets__0__name": "NotFound"}) - - @test.requireCapability(dialect=In("mysql", "postgres")) - async def test_filter_not_condition(self): - obj0 = await testmodels.JSONFields.create( - data={ - "breed": "labrador", - "owner": { - "name": "Bob", - "last": None, - "other_pets": [ - { - "name": "Fishy", - } - ], - }, - } - ) - obj1 = await testmodels.JSONFields.create( - data={ - 
"breed": "husky", - "owner": { - "name": "Goldast", - "last": None, - "other_pets": [ - { - "name": "Fishy", - } - ], - }, - } - ) - - obj2 = await testmodels.JSONFields.get(data__filter={"breed__not": "husky"}) - obj3 = await testmodels.JSONFields.get(data__filter={"breed__not": "labrador"}) - self.assertEqual(obj0, obj2) - self.assertEqual(obj1, obj3) - - @test.requireCapability(dialect=In("mysql", "postgres")) - async def test_filter_is_null_condition(self): - obj0 = await testmodels.JSONFields.create( - data={ - "breed": "labrador", - "owner": { - "name": "Boby", - "last": "Cloud", - "other_pets": [ - { - "name": "Fishy", - } - ], - }, - } - ) - - obj1 = await testmodels.JSONFields.create( - data={ - "breed": "labrador", - "owner": { - "name": None, - "last": "Cloud", - "other_pets": [ - { - "name": "Fishy", - } - ], - }, - } - ) - - obj2 = await testmodels.JSONFields.get(data__filter={"owner__name__isnull": False}) - obj3 = await testmodels.JSONFields.get(data__filter={"owner__name__isnull": True}) - self.assertEqual(obj0, obj2) - self.assertEqual(obj1, obj3) - - @test.requireCapability(dialect=In("mysql", "postgres")) - async def test_filter_not_is_null_condition(self): - obj0 = await testmodels.JSONFields.create( - data={ - "breed": "labrador", - "owner": { - "name": "Boby", - "last": "Cloud", - "other_pets": [ - { - "name": "Fishy", - } - ], - }, - } - ) - - obj1 = await testmodels.JSONFields.create( - data={ - "breed": "labrador", - "owner": { - "name": None, - "last": "Cloud", - "other_pets": [ - { - "name": "Fishy", - } - ], - }, - } - ) - - obj2 = await testmodels.JSONFields.get(data__filter={"owner__name__not_isnull": True}) - obj3 = await testmodels.JSONFields.get(data__filter={"owner__name__not_isnull": False}) - self.assertEqual(obj0, obj2) - self.assertEqual(obj1, obj3) - - async def test_values(self): - obj0 = await testmodels.JSONFields.create(data={"some": ["text", 3]}) - values = await testmodels.JSONFields.filter(id=obj0.id).values("data") - 
self.assertEqual(values[0]["data"], {"some": ["text", 3]}) - - async def test_values_list(self): - obj0 = await testmodels.JSONFields.create(data={"some": ["text", 3]}) - values = await testmodels.JSONFields.filter(id=obj0.id).values_list("data", flat=True) - self.assertEqual(values[0], {"some": ["text", 3]}) - - def test_unique_fail(self): - with self.assertRaisesRegex(ConfigurationError, "can't be indexed"): - JSONField(unique=True) - - def test_index_fail(self): - with self.assertRaisesRegex(ConfigurationError, "can't be indexed"): - with self.assertWarnsRegex( - DeprecationWarning, "`index` is deprecated, please use `db_index` instead" - ): - JSONField(index=True) - with self.assertRaisesRegex(ConfigurationError, "can't be indexed"): - JSONField(db_index=True) - - async def test_validate_str(self): - obj0 = await testmodels.JSONFields.create(data=[], data_validate='["text", 5]') - obj = await testmodels.JSONFields.get(id=obj0.id) - self.assertEqual(obj.data_validate, ["text", 5]) - - async def test_validate_dict(self): - obj0 = await testmodels.JSONFields.create(data=[], data_validate={"some": ["text", 3]}) - obj = await testmodels.JSONFields.get(id=obj0.id) - self.assertEqual(obj.data_validate, {"some": ["text", 3]}) - - async def test_validate_list(self): - obj0 = await testmodels.JSONFields.create(data=[], data_validate=["text", 3]) - obj = await testmodels.JSONFields.get(id=obj0.id) - self.assertEqual(obj.data_validate, ["text", 3]) + + +@pytest.mark.asyncio +async def test_update(db): + """Test JSON field update.""" + obj0 = await testmodels.JSONFields.create(data={"some": ["text", 3]}) + await testmodels.JSONFields.filter(id=obj0.id).update(data={"other": ["text", 5]}) + obj = await testmodels.JSONFields.get(id=obj0.id) + assert obj.data == {"other": ["text", 5]} + assert obj.data_null is None + + +@pytest.mark.asyncio +async def test_dict_str(db): + """Test JSON field with dict from string.""" + obj0 = await testmodels.JSONFields.create(data={"some": 
["text", 3]}) + + obj = await testmodels.JSONFields.get(id=obj0.id) + assert obj.data == {"some": ["text", 3]} + + await testmodels.JSONFields.filter(id=obj0.id).update(data='{"other": ["text", 5]}') + obj = await testmodels.JSONFields.get(id=obj0.id) + assert obj.data == {"other": ["text", 5]} + + +@pytest.mark.asyncio +async def test_list_str(db): + """Test JSON field with list from string.""" + obj = await testmodels.JSONFields.create(data='["text", 3]') + obj0 = await testmodels.JSONFields.get(id=obj.id) + assert obj0.data == ["text", 3] + + await testmodels.JSONFields.filter(id=obj.id).update(data='["text", 5]') + obj0 = await testmodels.JSONFields.get(id=obj.id) + assert obj0.data == ["text", 5] + + +@pytest.mark.asyncio +async def test_list(db): + """Test JSON field with list data.""" + obj0 = await testmodels.JSONFields.create(data=["text", 3]) + obj = await testmodels.JSONFields.get(id=obj0.id) + assert obj.data == ["text", 3] + assert obj.data_null is None + await obj.save() + obj2 = await testmodels.JSONFields.get(id=obj.id) + assert obj == obj2 + + +@requireCapability(dialect=In("mysql", "postgres")) +@pytest.mark.asyncio +async def test_list_contains(db): + """Test JSON contains filter on list.""" + await testmodels.JSONFields.create(data=["text", 3, {"msg": "msg2"}]) + obj = await testmodels.JSONFields.filter(data__contains=[{"msg": "msg2"}]).first() + assert obj.data == ["text", 3, {"msg": "msg2"}] + await obj.save() + obj2 = await testmodels.JSONFields.get(id=obj.id) + assert obj == obj2 + + +@requireCapability(dialect=In("mysql", "postgres")) +@pytest.mark.asyncio +async def test_list_contained_by(db): + """Test JSON contained_by filter on list.""" + obj0 = await testmodels.JSONFields.create(data=["text"]) + obj1 = await testmodels.JSONFields.create(data=["tortoise", "msg"]) + obj2 = await testmodels.JSONFields.create(data=["tortoise"]) + obj3 = await testmodels.JSONFields.create(data=["new_message", "some_message"]) + objs = set(await 
testmodels.JSONFields.filter(data__contained_by=["text", "tortoise", "msg"])) + created_objs = {obj0, obj1, obj2} + assert created_objs == objs + assert obj3 not in objs + + +@requireCapability(dialect=In("mysql", "postgres")) +@pytest.mark.asyncio +async def test_filter(db): + """Test JSON filter with nested data.""" + obj0 = await testmodels.JSONFields.create( + data={ + "breed": "labrador", + "owner": { + "name": "Bob", + "last": None, + "other_pets": [ + { + "name": "Fishy", + } + ], + }, + } + ) + obj1 = await testmodels.JSONFields.create( + data={ + "breed": "husky", + "owner": { + "name": "Goldast", + "last": None, + "other_pets": [ + { + "name": None, + } + ], + }, + } + ) + obj = await testmodels.JSONFields.get(data__filter={"breed": "labrador"}) + obj2 = await testmodels.JSONFields.get(data__filter={"owner__name": "Goldast"}) + obj3 = await testmodels.JSONFields.get(data__filter={"owner__other_pets__0__name": "Fishy"}) + + assert obj0 == obj + assert obj1 == obj2 + assert obj0 == obj3 + + with pytest.raises(DoesNotExist): + await testmodels.JSONFields.get(data__filter={"breed": "NotFound"}) + with pytest.raises(DoesNotExist): + await testmodels.JSONFields.get(data__filter={"owner__other_pets__0__name": "NotFound"}) + + +@requireCapability(dialect=In("mysql", "postgres")) +@pytest.mark.asyncio +async def test_filter_not_condition(db): + """Test JSON filter with not condition.""" + obj0 = await testmodels.JSONFields.create( + data={ + "breed": "labrador", + "owner": { + "name": "Bob", + "last": None, + "other_pets": [ + { + "name": "Fishy", + } + ], + }, + } + ) + obj1 = await testmodels.JSONFields.create( + data={ + "breed": "husky", + "owner": { + "name": "Goldast", + "last": None, + "other_pets": [ + { + "name": "Fishy", + } + ], + }, + } + ) + + obj2 = await testmodels.JSONFields.get(data__filter={"breed__not": "husky"}) + obj3 = await testmodels.JSONFields.get(data__filter={"breed__not": "labrador"}) + assert obj0 == obj2 + assert obj1 == obj3 + + 
+@requireCapability(dialect=In("mysql", "postgres")) +@pytest.mark.asyncio +async def test_filter_is_null_condition(db): + """Test JSON filter with isnull condition.""" + obj0 = await testmodels.JSONFields.create( + data={ + "breed": "labrador", + "owner": { + "name": "Boby", + "last": "Cloud", + "other_pets": [ + { + "name": "Fishy", + } + ], + }, + } + ) + + obj1 = await testmodels.JSONFields.create( + data={ + "breed": "labrador", + "owner": { + "name": None, + "last": "Cloud", + "other_pets": [ + { + "name": "Fishy", + } + ], + }, + } + ) + + obj2 = await testmodels.JSONFields.get(data__filter={"owner__name__isnull": False}) + obj3 = await testmodels.JSONFields.get(data__filter={"owner__name__isnull": True}) + assert obj0 == obj2 + assert obj1 == obj3 + + +@requireCapability(dialect=In("mysql", "postgres")) +@pytest.mark.asyncio +async def test_filter_not_is_null_condition(db): + """Test JSON filter with not_isnull condition.""" + obj0 = await testmodels.JSONFields.create( + data={ + "breed": "labrador", + "owner": { + "name": "Boby", + "last": "Cloud", + "other_pets": [ + { + "name": "Fishy", + } + ], + }, + } + ) + + obj1 = await testmodels.JSONFields.create( + data={ + "breed": "labrador", + "owner": { + "name": None, + "last": "Cloud", + "other_pets": [ + { + "name": "Fishy", + } + ], + }, + } + ) + + obj2 = await testmodels.JSONFields.get(data__filter={"owner__name__not_isnull": True}) + obj3 = await testmodels.JSONFields.get(data__filter={"owner__name__not_isnull": False}) + assert obj0 == obj2 + assert obj1 == obj3 + + +@pytest.mark.asyncio +async def test_values(db): + """Test JSON field in values().""" + obj0 = await testmodels.JSONFields.create(data={"some": ["text", 3]}) + values = await testmodels.JSONFields.filter(id=obj0.id).values("data") + assert values[0]["data"] == {"some": ["text", 3]} + + +@pytest.mark.asyncio +async def test_values_list(db): + """Test JSON field in values_list().""" + obj0 = await testmodels.JSONFields.create(data={"some": 
["text", 3]}) + values = await testmodels.JSONFields.filter(id=obj0.id).values_list("data", flat=True) + assert values[0] == {"some": ["text", 3]} + + +def test_unique_fail(): + """Test that JSONField cannot be unique.""" + with pytest.raises(ConfigurationError, match="can't be indexed"): + JSONField(unique=True) + + +def test_index_fail(): + """Test that JSONField cannot be indexed.""" + with pytest.raises(ConfigurationError, match="can't be indexed"): + with pytest.warns( + DeprecationWarning, match="`index` is deprecated, please use `db_index` instead" + ): + JSONField(index=True) + with pytest.raises(ConfigurationError, match="can't be indexed"): + JSONField(db_index=True) + + +@pytest.mark.asyncio +async def test_validate_str(db): + """Test JSON field with validate from string.""" + obj0 = await testmodels.JSONFields.create(data=[], data_validate='["text", 5]') + obj = await testmodels.JSONFields.get(id=obj0.id) + assert obj.data_validate == ["text", 5] + + +@pytest.mark.asyncio +async def test_validate_dict(db): + """Test JSON field with validate from dict.""" + obj0 = await testmodels.JSONFields.create(data=[], data_validate={"some": ["text", 3]}) + obj = await testmodels.JSONFields.get(id=obj0.id) + assert obj.data_validate == {"some": ["text", 3]} + + +@pytest.mark.asyncio +async def test_validate_list(db): + """Test JSON field with validate from list.""" + obj0 = await testmodels.JSONFields.create(data=[], data_validate=["text", 3]) + obj = await testmodels.JSONFields.get(id=obj0.id) + assert obj.data_validate == ["text", 3] diff --git a/tests/fields/test_m2m.py b/tests/fields/test_m2m.py index 70357aad4..566295d41 100644 --- a/tests/fields/test_m2m.py +++ b/tests/fields/test_m2m.py @@ -1,128 +1,162 @@ +import pytest + from tests import testmodels -from tortoise.contrib import test from tortoise.exceptions import OperationalError from tortoise.fields import ManyToManyField -class TestManyToManyField(test.TestCase): - async def test_empty(self): - await 
testmodels.M2MOne.create() +@pytest.mark.asyncio +async def test_empty(db): + """Test creating M2M model without relations.""" + await testmodels.M2MOne.create() - async def test__add(self): - one = await testmodels.M2MOne.create(name="One") - two = await testmodels.M2MTwo.create(name="Two") - await one.two.add(two) - self.assertEqual(await one.two, [two]) - self.assertEqual(await two.one, [one]) - async def test__add__nothing(self): - one = await testmodels.M2MOne.create(name="One") - await one.two.add() +@pytest.mark.asyncio +async def test__add(db): + """Test adding a related object via M2M relation.""" + one = await testmodels.M2MOne.create(name="One") + two = await testmodels.M2MTwo.create(name="Two") + await one.two.add(two) + assert await one.two == [two] + assert await two.one == [one] - async def test__add__reverse(self): - one = await testmodels.M2MOne.create(name="One") - two = await testmodels.M2MTwo.create(name="Two") - await two.one.add(one) - self.assertEqual(await one.two, [two]) - self.assertEqual(await two.one, [one]) - async def test__add__many(self): - one = await testmodels.M2MOne.create(name="One") - two = await testmodels.M2MTwo.create(name="Two") - await one.two.add(two) +@pytest.mark.asyncio +async def test__add__nothing(db): + """Test adding nothing to M2M relation.""" + one = await testmodels.M2MOne.create(name="One") + await one.two.add() + + +@pytest.mark.asyncio +async def test__add__reverse(db): + """Test adding via reverse M2M relation.""" + one = await testmodels.M2MOne.create(name="One") + two = await testmodels.M2MTwo.create(name="Two") + await two.one.add(one) + assert await one.two == [two] + assert await two.one == [one] + + +@pytest.mark.asyncio +async def test__add__many(db): + """Test adding same object multiple times (should be idempotent).""" + one = await testmodels.M2MOne.create(name="One") + two = await testmodels.M2MTwo.create(name="Two") + await one.two.add(two) + await one.two.add(two) + await two.one.add(one) + 
assert await one.two == [two] + assert await two.one == [one] + + +@pytest.mark.asyncio +async def test__add__two(db): + """Test adding multiple related objects at once.""" + one = await testmodels.M2MOne.create(name="One") + two1 = await testmodels.M2MTwo.create(name="Two") + two2 = await testmodels.M2MTwo.create(name="Two") + await one.two.add(two1, two2) + assert await one.two == [two1, two2] + assert await two1.one == [one] + assert await two2.one == [one] + + +@pytest.mark.asyncio +async def test__remove(db): + """Test removing one related object from M2M relation.""" + one = await testmodels.M2MOne.create(name="One") + two1 = await testmodels.M2MTwo.create(name="Two") + two2 = await testmodels.M2MTwo.create(name="Two") + await one.two.add(two1, two2) + await one.two.remove(two1) + assert await one.two == [two2] + assert await two1.one == [] + assert await two2.one == [one] + + +@pytest.mark.asyncio +async def test__remove__many(db): + """Test removing multiple related objects at once.""" + one = await testmodels.M2MOne.create(name="One") + two1 = await testmodels.M2MTwo.create(name="Two1") + two2 = await testmodels.M2MTwo.create(name="Two2") + two3 = await testmodels.M2MTwo.create(name="Two3") + await one.two.add(two1, two2, two3) + await one.two.remove(two1, two2) + assert await one.two == [two3] + assert await two1.one == [] + assert await two2.one == [] + assert await two3.one == [one] + + +@pytest.mark.asyncio +async def test__remove__blank(db): + """Test that removing nothing raises OperationalError.""" + one = await testmodels.M2MOne.create(name="One") + with pytest.raises(OperationalError, match=r"remove\(\) called on no instances"): + await one.two.remove() + + +@pytest.mark.asyncio +async def test__clear(db): + """Test clearing all related objects from M2M relation.""" + one = await testmodels.M2MOne.create(name="One") + two1 = await testmodels.M2MTwo.create(name="Two") + two2 = await testmodels.M2MTwo.create(name="Two") + await one.two.add(two1, 
two2) + await one.two.clear() + assert await one.two == [] + assert await two1.one == [] + assert await two2.one == [] + + +@pytest.mark.asyncio +async def test__uninstantiated_add(db): + """Test that adding to unsaved model raises OperationalError.""" + one = testmodels.M2MOne(name="One") + two = await testmodels.M2MTwo.create(name="Two") + with pytest.raises(OperationalError, match=r"You should first call .save\(\) on "): await one.two.add(two) + + +@pytest.mark.asyncio +async def test__add_uninstantiated(db): + """Test that adding unsaved model raises OperationalError.""" + one = testmodels.M2MOne(name="One") + two = await testmodels.M2MTwo.create(name="Two") + with pytest.raises(OperationalError, match=r"You should first call .save\(\) on "): await two.one.add(one) - self.assertEqual(await one.two, [two]) - self.assertEqual(await two.one, [one]) - - async def test__add__two(self): - one = await testmodels.M2MOne.create(name="One") - two1 = await testmodels.M2MTwo.create(name="Two") - two2 = await testmodels.M2MTwo.create(name="Two") - await one.two.add(two1, two2) - self.assertEqual(await one.two, [two1, two2]) - self.assertEqual(await two1.one, [one]) - self.assertEqual(await two2.one, [one]) - - async def test__remove(self): - one = await testmodels.M2MOne.create(name="One") - two1 = await testmodels.M2MTwo.create(name="Two") - two2 = await testmodels.M2MTwo.create(name="Two") - await one.two.add(two1, two2) - await one.two.remove(two1) - self.assertEqual(await one.two, [two2]) - self.assertEqual(await two1.one, []) - self.assertEqual(await two2.one, [one]) - - async def test__remove__many(self): - one = await testmodels.M2MOne.create(name="One") - two1 = await testmodels.M2MTwo.create(name="Two1") - two2 = await testmodels.M2MTwo.create(name="Two2") - two3 = await testmodels.M2MTwo.create(name="Two3") - await one.two.add(two1, two2, two3) - await one.two.remove(two1, two2) - self.assertEqual(await one.two, [two3]) - self.assertEqual(await two1.one, []) - 
self.assertEqual(await two2.one, []) - self.assertEqual(await two3.one, [one]) - - async def test__remove__blank(self): - one = await testmodels.M2MOne.create(name="One") - with self.assertRaisesRegex(OperationalError, r"remove\(\) called on no instances"): - await one.two.remove() - - async def test__clear(self): - one = await testmodels.M2MOne.create(name="One") - two1 = await testmodels.M2MTwo.create(name="Two") - two2 = await testmodels.M2MTwo.create(name="Two") - await one.two.add(two1, two2) - await one.two.clear() - self.assertEqual(await one.two, []) - self.assertEqual(await two1.one, []) - self.assertEqual(await two2.one, []) - - async def test__uninstantiated_add(self): - one = testmodels.M2MOne(name="One") - two = await testmodels.M2MTwo.create(name="Two") - with self.assertRaisesRegex( - OperationalError, r"You should first call .save\(\) on " - ): - await one.two.add(two) - - async def test__add_uninstantiated(self): - one = testmodels.M2MOne(name="One") - two = await testmodels.M2MTwo.create(name="Two") - with self.assertRaisesRegex( - OperationalError, r"You should first call .save\(\) on " - ): - await two.one.add(one) - - async def test_create_unique_index(self): - message = "Parameter `create_unique_index` is deprecated! Use `unique` instead." 
- with self.assertWarnsRegex(DeprecationWarning, message): - field = ManyToManyField("models.Foo", create_unique_index=False) - assert field.unique is False - with self.assertWarnsRegex(DeprecationWarning, message): - field = ManyToManyField("models.Foo", create_unique_index=False, unique=True) - assert field.unique is False - with self.assertWarnsRegex(DeprecationWarning, message): - field = ManyToManyField("models.Foo", create_unique_index=True) - assert field.unique is True - with self.assertWarnsRegex(DeprecationWarning, message): - field = ManyToManyField("models.Foo", create_unique_index=True, unique=False) - assert field.unique is True - field = ManyToManyField( - "models.Group", - ) - assert field.unique is True - field = ManyToManyField( - "models.Group", - "user_group", - "user_id", - "group_id", - "users", - "CASCADE", - True, - False, - ) - assert field.unique is False + + +@pytest.mark.asyncio +async def test_create_unique_index(db): + """Test deprecated create_unique_index parameter behavior.""" + message = "Parameter `create_unique_index` is deprecated! Use `unique` instead." 
+ with pytest.warns(DeprecationWarning, match=message): + field = ManyToManyField("models.Foo", create_unique_index=False) + assert field.unique is False + with pytest.warns(DeprecationWarning, match=message): + field = ManyToManyField("models.Foo", create_unique_index=False, unique=True) + assert field.unique is False + with pytest.warns(DeprecationWarning, match=message): + field = ManyToManyField("models.Foo", create_unique_index=True) + assert field.unique is True + with pytest.warns(DeprecationWarning, match=message): + field = ManyToManyField("models.Foo", create_unique_index=True, unique=False) + assert field.unique is True + field = ManyToManyField( + "models.Group", + ) + assert field.unique is True + field = ManyToManyField( + "models.Group", + "user_group", + "user_id", + "group_id", + "users", + "CASCADE", + True, + False, + ) + assert field.unique is False diff --git a/tests/fields/test_m2m_uuid.py b/tests/fields/test_m2m_uuid.py index 9b8a3f1e1..7648750b2 100644 --- a/tests/fields/test_m2m_uuid.py +++ b/tests/fields/test_m2m_uuid.py @@ -1,115 +1,165 @@ +import pytest + from tests import testmodels -from tortoise.contrib import test from tortoise.exceptions import OperationalError -class TestManyToManyUUIDField(test.TestCase): - UUIDPkModel = testmodels.UUIDPkModel - UUIDM2MRelatedModel = testmodels.UUIDM2MRelatedModel - - async def test_empty(self): - await self.UUIDM2MRelatedModel.create() - - async def test__add(self): - one = await self.UUIDM2MRelatedModel.create() - two = await self.UUIDPkModel.create() - await one.models.add(two) - self.assertEqual(await one.models, [two]) - self.assertEqual(await two.peers, [one]) - - async def test__add__nothing(self): - one = await self.UUIDPkModel.create() - await one.peers.add() +# Parameterize to test both standard and source-field models +@pytest.fixture( + params=[ + pytest.param( + (testmodels.UUIDPkModel, testmodels.UUIDM2MRelatedModel), + id="standard", + ), + pytest.param( + 
(testmodels.UUIDPkSourceModel, testmodels.UUIDM2MRelatedSourceModel), + id="sourced", + ), + ] +) +def m2m_uuid_models(request): + """ + Fixture providing UUID model classes for M2M tests. + + Tests both standard UUID models and source-field models with customized DB names. + """ + return request.param + + +@pytest.mark.asyncio +async def test_empty(db, m2m_uuid_models): + UUIDPkModel, UUIDM2MRelatedModel = m2m_uuid_models + await UUIDM2MRelatedModel.create() + + +@pytest.mark.asyncio +async def test__add(db, m2m_uuid_models): + UUIDPkModel, UUIDM2MRelatedModel = m2m_uuid_models + one = await UUIDM2MRelatedModel.create() + two = await UUIDPkModel.create() + await one.models.add(two) + assert await one.models == [two] + assert await two.peers == [one] + + +@pytest.mark.asyncio +async def test__add__nothing(db, m2m_uuid_models): + UUIDPkModel, UUIDM2MRelatedModel = m2m_uuid_models + one = await UUIDPkModel.create() + await one.peers.add() + + +@pytest.mark.asyncio +async def test__add__reverse(db, m2m_uuid_models): + UUIDPkModel, UUIDM2MRelatedModel = m2m_uuid_models + one = await UUIDM2MRelatedModel.create() + two = await UUIDPkModel.create() + await two.peers.add(one) + assert await one.models == [two] + assert await two.peers == [one] + + +@pytest.mark.asyncio +async def test__add__many(db, m2m_uuid_models): + UUIDPkModel, UUIDM2MRelatedModel = m2m_uuid_models + one = await UUIDPkModel.create() + two = await UUIDM2MRelatedModel.create() + await one.peers.add(two) + await one.peers.add(two) + await two.models.add(one) + assert await one.peers == [two] + assert await two.models == [one] + + +@pytest.mark.asyncio +async def test__add__two(db, m2m_uuid_models): + UUIDPkModel, UUIDM2MRelatedModel = m2m_uuid_models + one = await UUIDPkModel.create() + two1 = await UUIDM2MRelatedModel.create() + two2 = await UUIDM2MRelatedModel.create() + await one.peers.add(two1, two2) + assert set(await one.peers) == {two1, two2} + assert await two1.models == [one] + assert await 
two2.models == [one] + + +@pytest.mark.asyncio +async def test__add__two_two(db, m2m_uuid_models): + UUIDPkModel, UUIDM2MRelatedModel = m2m_uuid_models + one1 = await UUIDPkModel.create() + one2 = await UUIDPkModel.create() + two1 = await UUIDM2MRelatedModel.create() + two2 = await UUIDM2MRelatedModel.create() + await one1.peers.add(two1, two2) + await one2.peers.add(two1, two2) + assert set(await one1.peers) == {two1, two2} + assert set(await one2.peers) == {two1, two2} + assert set(await two1.models) == {one1, one2} + assert set(await two2.models) == {one1, one2} + + +@pytest.mark.asyncio +async def test__remove(db, m2m_uuid_models): + UUIDPkModel, UUIDM2MRelatedModel = m2m_uuid_models + one = await UUIDPkModel.create() + two1 = await UUIDM2MRelatedModel.create() + two2 = await UUIDM2MRelatedModel.create() + await one.peers.add(two1, two2) + await one.peers.remove(two1) + assert await one.peers == [two2] + assert await two1.models == [] + assert await two2.models == [one] + + +@pytest.mark.asyncio +async def test__remove__many(db, m2m_uuid_models): + UUIDPkModel, UUIDM2MRelatedModel = m2m_uuid_models + one = await UUIDPkModel.create() + two1 = await UUIDM2MRelatedModel.create() + two2 = await UUIDM2MRelatedModel.create() + two3 = await UUIDM2MRelatedModel.create() + await one.peers.add(two1, two2, two3) + await one.peers.remove(two1, two2) + assert await one.peers == [two3] + assert await two1.models == [] + assert await two2.models == [] + assert await two3.models == [one] + + +@pytest.mark.asyncio +async def test__remove__blank(db, m2m_uuid_models): + UUIDPkModel, UUIDM2MRelatedModel = m2m_uuid_models + one = await UUIDPkModel.create() + with pytest.raises(OperationalError, match=r"remove\(\) called on no instances"): + await one.peers.remove() + + +@pytest.mark.asyncio +async def test__clear(db, m2m_uuid_models): + UUIDPkModel, UUIDM2MRelatedModel = m2m_uuid_models + one = await UUIDPkModel.create() + two1 = await UUIDM2MRelatedModel.create() + two2 = await 
UUIDM2MRelatedModel.create() + await one.peers.add(two1, two2) + await one.peers.clear() + assert await one.peers == [] + assert await two1.models == [] + assert await two2.models == [] + + +@pytest.mark.asyncio +async def test__uninstantiated_add(db, m2m_uuid_models): + UUIDPkModel, UUIDM2MRelatedModel = m2m_uuid_models + one = UUIDPkModel() + two = await UUIDM2MRelatedModel.create() + with pytest.raises(OperationalError, match=r"You should first call .save\(\) on"): + await one.peers.add(two) - async def test__add__reverse(self): - one = await self.UUIDM2MRelatedModel.create() - two = await self.UUIDPkModel.create() - await two.peers.add(one) - self.assertEqual(await one.models, [two]) - self.assertEqual(await two.peers, [one]) - async def test__add__many(self): - one = await self.UUIDPkModel.create() - two = await self.UUIDM2MRelatedModel.create() - await one.peers.add(two) - await one.peers.add(two) +@pytest.mark.asyncio +async def test__add_uninstantiated(db, m2m_uuid_models): + UUIDPkModel, UUIDM2MRelatedModel = m2m_uuid_models + one = UUIDPkModel() + two = await UUIDM2MRelatedModel.create() + with pytest.raises(OperationalError, match=r"You should first call .save\(\) on"): await two.models.add(one) - self.assertEqual(await one.peers, [two]) - self.assertEqual(await two.models, [one]) - - async def test__add__two(self): - one = await self.UUIDPkModel.create() - two1 = await self.UUIDM2MRelatedModel.create() - two2 = await self.UUIDM2MRelatedModel.create() - await one.peers.add(two1, two2) - self.assertEqual(set(await one.peers), {two1, two2}) - self.assertEqual(await two1.models, [one]) - self.assertEqual(await two2.models, [one]) - - async def test__add__two_two(self): - one1 = await self.UUIDPkModel.create() - one2 = await self.UUIDPkModel.create() - two1 = await self.UUIDM2MRelatedModel.create() - two2 = await self.UUIDM2MRelatedModel.create() - await one1.peers.add(two1, two2) - await one2.peers.add(two1, two2) - self.assertEqual(set(await one1.peers), 
{two1, two2}) - self.assertEqual(set(await one2.peers), {two1, two2}) - self.assertEqual(set(await two1.models), {one1, one2}) - self.assertEqual(set(await two2.models), {one1, one2}) - - async def test__remove(self): - one = await self.UUIDPkModel.create() - two1 = await self.UUIDM2MRelatedModel.create() - two2 = await self.UUIDM2MRelatedModel.create() - await one.peers.add(two1, two2) - await one.peers.remove(two1) - self.assertEqual(await one.peers, [two2]) - self.assertEqual(await two1.models, []) - self.assertEqual(await two2.models, [one]) - - async def test__remove__many(self): - one = await self.UUIDPkModel.create() - two1 = await self.UUIDM2MRelatedModel.create() - two2 = await self.UUIDM2MRelatedModel.create() - two3 = await self.UUIDM2MRelatedModel.create() - await one.peers.add(two1, two2, two3) - await one.peers.remove(two1, two2) - self.assertEqual(await one.peers, [two3]) - self.assertEqual(await two1.models, []) - self.assertEqual(await two2.models, []) - self.assertEqual(await two3.models, [one]) - - async def test__remove__blank(self): - one = await self.UUIDPkModel.create() - with self.assertRaisesRegex(OperationalError, r"remove\(\) called on no instances"): - await one.peers.remove() - - async def test__clear(self): - one = await self.UUIDPkModel.create() - two1 = await self.UUIDM2MRelatedModel.create() - two2 = await self.UUIDM2MRelatedModel.create() - await one.peers.add(two1, two2) - await one.peers.clear() - self.assertEqual(await one.peers, []) - self.assertEqual(await two1.models, []) - self.assertEqual(await two2.models, []) - - async def test__uninstantiated_add(self): - one = self.UUIDPkModel() - two = await self.UUIDM2MRelatedModel.create() - with self.assertRaisesRegex(OperationalError, r"You should first call .save\(\) on"): - await one.peers.add(two) - - async def test__add_uninstantiated(self): - one = self.UUIDPkModel() - two = await self.UUIDM2MRelatedModel.create() - with self.assertRaisesRegex(OperationalError, r"You should 
first call .save\(\) on"): - await two.models.add(one) - - # TODO: Sorting? - - -class TestManyToManyUUIDSourceField(TestManyToManyUUIDField): - UUIDPkModel = testmodels.UUIDPkSourceModel # type: ignore - UUIDM2MRelatedModel = testmodels.UUIDM2MRelatedSourceModel # type: ignore diff --git a/tests/fields/test_o2o_with_unique.py b/tests/fields/test_o2o_with_unique.py index ff8cda809..46f10c912 100644 --- a/tests/fields/test_o2o_with_unique.py +++ b/tests/fields/test_o2o_with_unique.py @@ -1,104 +1,131 @@ +import pytest + from tests import testmodels -from tortoise.contrib import test from tortoise.exceptions import IntegrityError, OperationalError from tortoise.queryset import QuerySet -class TestOneToOneFieldWithUnique(test.TestCase): - async def test_principal__empty(self): - with self.assertRaises(IntegrityError): - await testmodels.Principal.create() - - async def test_principal__create_by_id(self): - school = await testmodels.School.create(id=1024, name="School1") - principal = await testmodels.Principal.create(name="Sang-Heon Jeon", school_id=school.id) - self.assertEqual(principal.school_id, school.id) - self.assertEqual(await school.principal, principal) - - async def test_principal__create_by_name(self): - school = await testmodels.School.create(id=1024, name="School1") - principal = await testmodels.Principal.create(name="Sang-Heon Jeon", school=school) - await principal.fetch_related("school") - self.assertEqual(principal.school, school) - self.assertEqual(await school.principal, principal) - - async def test_principal__by_name__created_prefetched(self): - school = await testmodels.School.create(id=1024, name="School1") - principal = await testmodels.Principal.create(name="Sang-Heon Jeon", school=school) - self.assertEqual(principal.school, school) - self.assertEqual(await school.principal, principal) - - async def test_principal__by_name__unfetched(self): - school = await testmodels.School.create(id=1024, name="School1") - principal = await 
testmodels.Principal.create(name="Sang-Heon Jeon", school=school) - principal = await testmodels.Principal.get(id=principal.id) - self.assertIsInstance(principal.school, QuerySet) - - async def test_principal__by_name__re_awaited(self): - school = await testmodels.School.create(id=1024, name="School1") - principal = await testmodels.Principal.create(name="Sang-Heon Jeon", school=school) - await principal.fetch_related("school") - self.assertEqual(principal.school, school) - self.assertEqual(await principal.school, school) - - async def test_principal__by_name__awaited(self): - school = await testmodels.School.create(id=1024, name="School1") - principal = await testmodels.Principal.create(name="Sang-Heon Jeon", school=school) - principal = await testmodels.Principal.get(id=principal.id) - self.assertEqual(await principal.school, school) - self.assertEqual(await school.principal, principal) - - async def test_update_by_name(self): - school = await testmodels.School.create(id=1024, name="School1") - school2 = await testmodels.School.create(id=2048, name="School2") - principal0 = await testmodels.Principal.create(name="Sang-Heon Jeon", school=school) - - await testmodels.Principal.filter(id=principal0.id).update(school=school2) - principal = await testmodels.Principal.get(id=principal0.id) - - await principal.fetch_related("school") - self.assertEqual(principal.school, school2) - self.assertEqual(await school.principal, None) - self.assertEqual(await school2.principal, principal) - - async def test_update_by_id(self): - school = await testmodels.School.create(id=1024, name="School1") - school2 = await testmodels.School.create(id=2048, name="School2") - principal0 = await testmodels.Principal.create(name="Sang-Heon Jeon", school_id=school.id) - - await testmodels.Principal.filter(id=principal0.id).update(school_id=school2.id) - principal = await testmodels.Principal.get(id=principal0.id) - - self.assertEqual(principal.school_id, school2.id) - self.assertEqual(await 
school.principal, None) - self.assertEqual(await school2.principal, principal) - - async def test_delete_by_name(self): - school = await testmodels.School.create(id=1024, name="School1") - principal = await testmodels.Principal.create(name="Sang-Heon Jeon", school=school) - del principal.school - with self.assertRaises(IntegrityError): - await principal.save() - - async def test_principal__uninstantiated_create(self): - school = await testmodels.School(id=1024, name="School1") - with self.assertRaisesRegex(OperationalError, "You should first call .save()"): - await testmodels.Principal.create(name="Sang-Heon Jeon", school=school) - - async def test_principal__instantiated_create(self): - school = await testmodels.School.create(id=1024, name="School1") - await testmodels.Principal.create(name="Sang-Heon Jeon", school=school) +@pytest.mark.asyncio +async def test_principal__empty(db): + with pytest.raises(IntegrityError): + await testmodels.Principal.create() + + +@pytest.mark.asyncio +async def test_principal__create_by_id(db): + school = await testmodels.School.create(id=1024, name="School1") + principal = await testmodels.Principal.create(name="Sang-Heon Jeon", school_id=school.id) + assert principal.school_id == school.id + assert await school.principal == principal + + +@pytest.mark.asyncio +async def test_principal__create_by_name(db): + school = await testmodels.School.create(id=1024, name="School1") + principal = await testmodels.Principal.create(name="Sang-Heon Jeon", school=school) + await principal.fetch_related("school") + assert principal.school == school + assert await school.principal == principal + + +@pytest.mark.asyncio +async def test_principal__by_name__created_prefetched(db): + school = await testmodels.School.create(id=1024, name="School1") + principal = await testmodels.Principal.create(name="Sang-Heon Jeon", school=school) + assert principal.school == school + assert await school.principal == principal + + +@pytest.mark.asyncio +async def 
test_principal__by_name__unfetched(db): + school = await testmodels.School.create(id=1024, name="School1") + principal = await testmodels.Principal.create(name="Sang-Heon Jeon", school=school) + principal = await testmodels.Principal.get(id=principal.id) + assert isinstance(principal.school, QuerySet) + + +@pytest.mark.asyncio +async def test_principal__by_name__re_awaited(db): + school = await testmodels.School.create(id=1024, name="School1") + principal = await testmodels.Principal.create(name="Sang-Heon Jeon", school=school) + await principal.fetch_related("school") + assert principal.school == school + assert await principal.school == school + + +@pytest.mark.asyncio +async def test_principal__by_name__awaited(db): + school = await testmodels.School.create(id=1024, name="School1") + principal = await testmodels.Principal.create(name="Sang-Heon Jeon", school=school) + principal = await testmodels.Principal.get(id=principal.id) + assert await principal.school == school + assert await school.principal == principal + + +@pytest.mark.asyncio +async def test_update_by_name(db): + school = await testmodels.School.create(id=1024, name="School1") + school2 = await testmodels.School.create(id=2048, name="School2") + principal0 = await testmodels.Principal.create(name="Sang-Heon Jeon", school=school) - async def test_principal__fetched_bool(self): - school = await testmodels.School.create(id=1024, name="School1") - await school.fetch_related("principal") - self.assertFalse(bool(school.principal)) + await testmodels.Principal.filter(id=principal0.id).update(school=school2) + principal = await testmodels.Principal.get(id=principal0.id) + + await principal.fetch_related("school") + assert principal.school == school2 + assert await school.principal is None + assert await school2.principal == principal + + +@pytest.mark.asyncio +async def test_update_by_id(db): + school = await testmodels.School.create(id=1024, name="School1") + school2 = await 
testmodels.School.create(id=2048, name="School2") + principal0 = await testmodels.Principal.create(name="Sang-Heon Jeon", school_id=school.id) + + await testmodels.Principal.filter(id=principal0.id).update(school_id=school2.id) + principal = await testmodels.Principal.get(id=principal0.id) + + assert principal.school_id == school2.id + assert await school.principal is None + assert await school2.principal == principal + + +@pytest.mark.asyncio +async def test_delete_by_name(db): + school = await testmodels.School.create(id=1024, name="School1") + principal = await testmodels.Principal.create(name="Sang-Heon Jeon", school=school) + del principal.school + with pytest.raises(IntegrityError): + await principal.save() + + +@pytest.mark.asyncio +async def test_principal__uninstantiated_create(db): + school = await testmodels.School(id=1024, name="School1") + with pytest.raises(OperationalError, match="You should first call .save()"): await testmodels.Principal.create(name="Sang-Heon Jeon", school=school) - await school.fetch_related("principal") - self.assertTrue(bool(school.principal)) - - async def test_principal__filter(self): - school = await testmodels.School.create(id=1024, name="School1") - principal = await testmodels.Principal.create(name="Sang-Heon Jeon1", school=school) - self.assertEqual(await school.principal.filter(name="Sang-Heon Jeon1"), principal) - self.assertEqual(await school.principal.filter(name="Sang-Heon Jeon2"), None) + + +@pytest.mark.asyncio +async def test_principal__instantiated_create(db): + school = await testmodels.School.create(id=1024, name="School1") + await testmodels.Principal.create(name="Sang-Heon Jeon", school=school) + + +@pytest.mark.asyncio +async def test_principal__fetched_bool(db): + school = await testmodels.School.create(id=1024, name="School1") + await school.fetch_related("principal") + assert not bool(school.principal) + await testmodels.Principal.create(name="Sang-Heon Jeon", school=school) + await 
school.fetch_related("principal") + assert bool(school.principal) + + +@pytest.mark.asyncio +async def test_principal__filter(db): + school = await testmodels.School.create(id=1024, name="School1") + principal = await testmodels.Principal.create(name="Sang-Heon Jeon1", school=school) + assert await school.principal.filter(name="Sang-Heon Jeon1") == principal + assert await school.principal.filter(name="Sang-Heon Jeon2") is None diff --git a/tests/fields/test_subclass.py b/tests/fields/test_subclass.py index 587682a75..cbfce18f8 100644 --- a/tests/fields/test_subclass.py +++ b/tests/fields/test_subclass.py @@ -1,13 +1,15 @@ +import pytest + from tests.fields.subclass_models import ( Contact, ContactTypeEnum, RaceParticipant, RacePlacingEnum, ) -from tortoise.contrib import test async def create_participants(): + """Helper to create race participants for tests.""" test1 = await RaceParticipant.create( first_name="Alex", place=RacePlacingEnum.FIRST, @@ -24,66 +26,78 @@ async def create_participants(): return test1, test2, test3, test4 -class TestEnumField(test.IsolatedTestCase): - """Tests the enumeration field.""" +@pytest.mark.asyncio +async def test_enum_field_create(db_subclass_fields): + """Asserts that the new field is saved properly.""" + test1, _, _, _ = await create_participants() + assert test1 in await RaceParticipant.all() + assert test1.place == RacePlacingEnum.FIRST - tortoise_test_modules = ["tests.fields.subclass_models"] - async def test_enum_field_create(self): - """Asserts that the new field is saved properly.""" - test1, _, _, _ = await create_participants() - self.assertIn(test1, await RaceParticipant.all()) - self.assertEqual(test1.place, RacePlacingEnum.FIRST) +@pytest.mark.asyncio +async def test_enum_field_update(db_subclass_fields): + """Asserts that the new field can be updated correctly.""" + test1, _, _, _ = await create_participants() + test1.place = RacePlacingEnum.SECOND + await test1.save() - async def test_enum_field_update(self): - 
"""Asserts that the new field can be updated correctly.""" - test1, _, _, _ = await create_participants() - test1.place = RacePlacingEnum.SECOND - await test1.save() + tied_second = await RaceParticipant.filter(place=RacePlacingEnum.SECOND) - tied_second = await RaceParticipant.filter(place=RacePlacingEnum.SECOND) + assert test1 in tied_second + assert len(tied_second) == 2 - self.assertIn(test1, tied_second) - self.assertEqual(len(tied_second), 2) - async def test_enum_field_filter(self): - """Assert that filters correctly select the enums.""" - await create_participants() +@pytest.mark.asyncio +async def test_enum_field_filter(db_subclass_fields): + """Assert that filters correctly select the enums.""" + await create_participants() - first_place = await RaceParticipant.filter(place=RacePlacingEnum.FIRST).first() + first_place = await RaceParticipant.filter(place=RacePlacingEnum.FIRST).first() + second_place = await RaceParticipant.filter(place=RacePlacingEnum.SECOND).first() - second_place = await RaceParticipant.filter(place=RacePlacingEnum.SECOND).first() + assert first_place.place == RacePlacingEnum.FIRST + assert second_place.place == RacePlacingEnum.SECOND - self.assertEqual(first_place.place, RacePlacingEnum.FIRST) - self.assertEqual(second_place.place, RacePlacingEnum.SECOND) - async def test_enum_field_delete(self): - """Assert that delete correctly removes the right participant by their place.""" - await create_participants() - await RaceParticipant.filter(place=RacePlacingEnum.FIRST).delete() - self.assertEqual(await RaceParticipant.all().count(), 3) +@pytest.mark.asyncio +async def test_enum_field_delete(db_subclass_fields): + """Assert that delete correctly removes the right participant by their place.""" + await create_participants() + await RaceParticipant.filter(place=RacePlacingEnum.FIRST).delete() + assert await RaceParticipant.all().count() == 3 - async def test_enum_field_default(self): - _, _, _, test4 = await create_participants() - 
self.assertEqual(test4.place, RacePlacingEnum.DNF) - async def test_enum_field_null(self): - """Assert that filtering by None selects the records which are null.""" - _, _, test3, test4 = await create_participants() +@pytest.mark.asyncio +async def test_enum_field_default(db_subclass_fields): + """Test that default enum value is applied correctly.""" + _, _, _, test4 = await create_participants() + assert test4.place == RacePlacingEnum.DNF - no_predictions = await RaceParticipant.filter(predicted_place__isnull=True) - self.assertIn(test3, no_predictions) - self.assertIn(test4, no_predictions) +@pytest.mark.asyncio +async def test_enum_field_null(db_subclass_fields): + """Assert that filtering by None selects the records which are null.""" + _, _, test3, test4 = await create_participants() + + no_predictions = await RaceParticipant.filter(predicted_place__isnull=True) + + assert test3 in no_predictions + assert test4 in no_predictions - async def test_update_with_int_enum_value(self): - contact = await Contact.create() - contact.type = ContactTypeEnum.home - await contact.save() - async def test_exception_on_invalid_data_type_in_int_field(self): - contact = await Contact.create() +@pytest.mark.asyncio +async def test_update_with_int_enum_value(db_subclass_fields): + """Test updating with integer enum value.""" + contact = await Contact.create() + contact.type = ContactTypeEnum.home + await contact.save() - contact.type = "not_int" - with self.assertRaises((TypeError, ValueError)): - await contact.save() + +@pytest.mark.asyncio +async def test_exception_on_invalid_data_type_in_int_field(db_subclass_fields): + """Test that invalid data types raise appropriate exceptions.""" + contact = await Contact.create() + + contact.type = "not_int" + with pytest.raises((TypeError, ValueError)): + await contact.save() diff --git a/tests/fields/test_subclass_filters.py b/tests/fields/test_subclass_filters.py index d4b63e77e..545934f31 100644 --- 
a/tests/fields/test_subclass_filters.py +++ b/tests/fields/test_subclass_filters.py @@ -1,99 +1,92 @@ +import pytest +import pytest_asyncio + from tests.fields.subclass_models import RaceParticipant, RacePlacingEnum -from tortoise.contrib import test -class TestCustomFieldFilters(test.IsolatedTestCase): - tortoise_test_modules = ["tests.fields.subclass_models"] +@pytest_asyncio.fixture +async def race_data(db_subclass_fields): + """Set up race participant data for filter tests.""" + await RaceParticipant.create( + first_name="George", place=RacePlacingEnum.FIRST, predicted_place=RacePlacingEnum.SECOND + ) + await RaceParticipant.create( + first_name="John", place=RacePlacingEnum.SECOND, predicted_place=RacePlacingEnum.THIRD + ) + await RaceParticipant.create(first_name="Paul", place=RacePlacingEnum.THIRD) + await RaceParticipant.create(first_name="Ringo", place=RacePlacingEnum.RUNNER_UP) + await RaceParticipant.create(first_name="Stuart", predicted_place=RacePlacingEnum.FIRST) + yield db_subclass_fields - async def asyncSetUp(self): - await super().asyncSetUp() - await RaceParticipant.create( - first_name="George", place=RacePlacingEnum.FIRST, predicted_place=RacePlacingEnum.SECOND - ) - await RaceParticipant.create( - first_name="John", place=RacePlacingEnum.SECOND, predicted_place=RacePlacingEnum.THIRD - ) - await RaceParticipant.create(first_name="Paul", place=RacePlacingEnum.THIRD) - await RaceParticipant.create(first_name="Ringo", place=RacePlacingEnum.RUNNER_UP) - await RaceParticipant.create(first_name="Stuart", predicted_place=RacePlacingEnum.FIRST) - async def test_equal(self): - self.assertEqual( - set( - await RaceParticipant.filter(place=RacePlacingEnum.FIRST).values_list( - "place", flat=True - ) - ), - {RacePlacingEnum.FIRST}, - ) +@pytest.mark.asyncio +async def test_equal(race_data): + """Test equal filter on custom enum field.""" + assert set( + await RaceParticipant.filter(place=RacePlacingEnum.FIRST).values_list("place", flat=True) + ) == 
{RacePlacingEnum.FIRST} - async def test_not(self): - self.assertEqual( - set( - await RaceParticipant.filter(place__not=RacePlacingEnum.FIRST).values_list( - "place", flat=True - ) - ), - { - RacePlacingEnum.SECOND, - RacePlacingEnum.THIRD, - RacePlacingEnum.RUNNER_UP, - RacePlacingEnum.DNF, - }, - ) - async def test_in(self): - self.assertSetEqual( - set( - await RaceParticipant.filter( - place__in=[RacePlacingEnum.DNF, RacePlacingEnum.RUNNER_UP] - ).values_list("place", flat=True) - ), - {RacePlacingEnum.DNF, RacePlacingEnum.RUNNER_UP}, +@pytest.mark.asyncio +async def test_not(race_data): + """Test not filter on custom enum field.""" + assert set( + await RaceParticipant.filter(place__not=RacePlacingEnum.FIRST).values_list( + "place", flat=True ) + ) == { + RacePlacingEnum.SECOND, + RacePlacingEnum.THIRD, + RacePlacingEnum.RUNNER_UP, + RacePlacingEnum.DNF, + } - async def test_not_in(self): - self.assertSetEqual( - set( - await RaceParticipant.filter( - place__not_in=[RacePlacingEnum.DNF, RacePlacingEnum.RUNNER_UP] - ).values_list("place", flat=True) - ), - {RacePlacingEnum.FIRST, RacePlacingEnum.SECOND, RacePlacingEnum.THIRD}, - ) - async def test_isnull(self): - self.assertSetEqual( - set( - await RaceParticipant.filter(predicted_place__isnull=True).values_list( - "first_name", flat=True - ) - ), - {"Paul", "Ringo"}, +@pytest.mark.asyncio +async def test_in(race_data): + """Test in filter on custom enum field.""" + assert set( + await RaceParticipant.filter( + place__in=[RacePlacingEnum.DNF, RacePlacingEnum.RUNNER_UP] + ).values_list("place", flat=True) + ) == {RacePlacingEnum.DNF, RacePlacingEnum.RUNNER_UP} + + +@pytest.mark.asyncio +async def test_not_in(race_data): + """Test not_in filter on custom enum field.""" + assert set( + await RaceParticipant.filter( + place__not_in=[RacePlacingEnum.DNF, RacePlacingEnum.RUNNER_UP] + ).values_list("place", flat=True) + ) == {RacePlacingEnum.FIRST, RacePlacingEnum.SECOND, RacePlacingEnum.THIRD} + + 
+@pytest.mark.asyncio +async def test_isnull(race_data): + """Test isnull filter on custom enum field.""" + assert set( + await RaceParticipant.filter(predicted_place__isnull=True).values_list( + "first_name", flat=True ) - self.assertSetEqual( - set( - await RaceParticipant.filter(predicted_place__isnull=False).values_list( - "first_name", flat=True - ) - ), - {"George", "John", "Stuart"}, + ) == {"Paul", "Ringo"} + assert set( + await RaceParticipant.filter(predicted_place__isnull=False).values_list( + "first_name", flat=True ) + ) == {"George", "John", "Stuart"} + - async def test_not_isnull(self): - self.assertSetEqual( - set( - await RaceParticipant.filter(predicted_place__not_isnull=False).values_list( - "first_name", flat=True - ) - ), - {"Paul", "Ringo"}, +@pytest.mark.asyncio +async def test_not_isnull(race_data): + """Test not_isnull filter on custom enum field.""" + assert set( + await RaceParticipant.filter(predicted_place__not_isnull=False).values_list( + "first_name", flat=True ) - self.assertSetEqual( - set( - await RaceParticipant.filter(predicted_place__not_isnull=True).values_list( - "first_name", flat=True - ) - ), - {"George", "John", "Stuart"}, + ) == {"Paul", "Ringo"} + assert set( + await RaceParticipant.filter(predicted_place__not_isnull=True).values_list( + "first_name", flat=True ) + ) == {"George", "John", "Stuart"} diff --git a/tests/fields/test_text.py b/tests/fields/test_text.py index 4e2885261..1d8b59389 100644 --- a/tests/fields/test_text.py +++ b/tests/fields/test_text.py @@ -1,49 +1,59 @@ +import pytest + from tests import testmodels -from tortoise.contrib import test from tortoise.exceptions import ConfigurationError, IntegrityError from tortoise.fields import TextField -class TestTextFields(test.TestCase): - async def test_empty(self): - with self.assertRaises(IntegrityError): - await testmodels.TextFields.create() - - async def test_create(self): - obj0 = await testmodels.TextFields.create(text="baaa" * 32000) - obj = await 
testmodels.TextFields.get(id=obj0.id) - self.assertEqual(obj.text, "baaa" * 32000) - self.assertEqual(obj.text_null, None) - await obj.save() - obj2 = await testmodels.TextFields.get(id=obj.id) - self.assertEqual(obj, obj2) - - async def test_values(self): - obj0 = await testmodels.TextFields.create(text="baa") - values = await testmodels.TextFields.get(id=obj0.id).values("text") - self.assertEqual(values["text"], "baa") - - async def test_values_list(self): - obj0 = await testmodels.TextFields.create(text="baa") - values = await testmodels.TextFields.get(id=obj0.id).values_list("text", flat=True) - self.assertEqual(values, "baa") - - def test_unique_fail(self): - msg = "TextField can't be indexed, consider CharField" - with self.assertRaisesRegex(ConfigurationError, msg): - with self.assertWarnsRegex( - DeprecationWarning, "`index` is deprecated, please use `db_index` instead" - ): - TextField(index=True) - with self.assertRaisesRegex(ConfigurationError, msg): - TextField(db_index=True) - - def test_index_fail(self): - with self.assertRaisesRegex(ConfigurationError, "can't be indexed, consider CharField"): - TextField(index=True) +@pytest.mark.asyncio +async def test_empty(db): + with pytest.raises(IntegrityError): + await testmodels.TextFields.create() + + +@pytest.mark.asyncio +async def test_create(db): + obj0 = await testmodels.TextFields.create(text="baaa" * 32000) + obj = await testmodels.TextFields.get(id=obj0.id) + assert obj.text == "baaa" * 32000 + assert obj.text_null is None + await obj.save() + obj2 = await testmodels.TextFields.get(id=obj.id) + assert obj == obj2 + + +@pytest.mark.asyncio +async def test_values(db): + obj0 = await testmodels.TextFields.create(text="baa") + values = await testmodels.TextFields.get(id=obj0.id).values("text") + assert values["text"] == "baa" + - def test_pk_deprecated(self): - with self.assertWarnsRegex( - DeprecationWarning, "TextField as a PrimaryKey is Deprecated, use CharField" +@pytest.mark.asyncio +async def 
test_values_list(db): + obj0 = await testmodels.TextFields.create(text="baa") + values = await testmodels.TextFields.get(id=obj0.id).values_list("text", flat=True) + assert values == "baa" + + +def test_unique_fail(): + msg = "TextField can't be indexed, consider CharField" + with pytest.raises(ConfigurationError, match=msg): + with pytest.warns( + DeprecationWarning, match="`index` is deprecated, please use `db_index` instead" ): - TextField(primary_key=True) + TextField(index=True) + with pytest.raises(ConfigurationError, match=msg): + TextField(db_index=True) + + +def test_index_fail(): + with pytest.raises(ConfigurationError, match="can't be indexed, consider CharField"): + TextField(index=True) + + +def test_pk_deprecated(): + with pytest.warns( + DeprecationWarning, match="TextField as a PrimaryKey is Deprecated, use CharField" + ): + TextField(primary_key=True) diff --git a/tests/fields/test_time.py b/tests/fields/test_time.py index 8a6c2af1f..bb5304d09 100644 --- a/tests/fields/test_time.py +++ b/tests/fields/test_time.py @@ -4,349 +4,503 @@ from time import sleep from unittest.mock import patch +import pytest import pytz from iso8601 import ParseError from tests import testmodels -from tortoise import Model, fields, timezone +from tortoise import fields, timezone from tortoise.contrib import test from tortoise.contrib.test.condition import NotIn from tortoise.exceptions import ConfigurationError, IntegrityError from tortoise.expressions import F from tortoise.timezone import get_default_timezone +# ============================================================================ +# TestEmpty -> test_empty_* +# ============================================================================ + + +@pytest.mark.asyncio +async def test_empty_datetime_fields(db): + """Test that creating DatetimeFields without required field raises IntegrityError.""" + with pytest.raises(IntegrityError): + await testmodels.DatetimeFields.create() + + +# 
============================================================================ +# TestDatetimeFields -> test_datetime_* +# ============================================================================ + + +@pytest.fixture(autouse=True) +def reset_timezone_cache(): + """Reset timezone cache before and after each test.""" + timezone._reset_timezone_cache() + yield + timezone._reset_timezone_cache() + + +def test_datetime_both_auto_bad(db): + """Test that setting both auto_now and auto_now_add raises ConfigurationError.""" + with pytest.raises( + ConfigurationError, match="You can choose only 'auto_now' or 'auto_now_add'" + ): + fields.DatetimeField(auto_now=True, auto_now_add=True) + + +@pytest.mark.asyncio +async def test_datetime_create(db): + """Test creating datetime fields and auto_now/auto_now_add behavior.""" + model = testmodels.DatetimeFields + now = timezone.now() + obj0 = await model.create(datetime=now) + obj = await model.get(id=obj0.id) + assert obj.datetime == now + assert obj.datetime_null is None + assert obj.datetime_auto - now < timedelta(microseconds=20000) + assert obj.datetime_add - now < timedelta(microseconds=20000) + datetime_auto = obj.datetime_auto + sleep(0.012) + await obj.save() + obj2 = await model.get(id=obj.id) + assert obj2.datetime == now + assert obj2.datetime_null is None + assert obj2.datetime_auto == obj.datetime_auto + assert obj2.datetime_auto != datetime_auto + assert obj2.datetime_auto - now > timedelta(microseconds=10000) + assert obj2.datetime_auto - now < timedelta(seconds=1) + assert obj2.datetime_add == obj.datetime_add + + +@pytest.mark.asyncio +async def test_datetime_update(db): + """Test updating datetime fields via filter().update().""" + model = testmodels.DatetimeFields + obj0 = await model.create(datetime=datetime(2019, 9, 1, 0, 0, 0, tzinfo=get_default_timezone())) + await model.filter(id=obj0.id).update( + datetime=datetime(2019, 9, 1, 6, 0, 8, tzinfo=get_default_timezone()) + ) + obj = await 
model.get(id=obj0.id) + assert obj.datetime == datetime(2019, 9, 1, 6, 0, 8, tzinfo=get_default_timezone()) + assert obj.datetime_null is None + + +@pytest.mark.asyncio +async def test_datetime_filter(db): + """Test filtering by datetime field.""" + model = testmodels.DatetimeFields + now = timezone.now() + obj = await model.create(datetime=now) + assert await model.filter(datetime=now).first() == obj + assert await model.annotate(d=F("datetime")).filter(d=now).first() == obj + + +@pytest.mark.asyncio +async def test_datetime_cast(db): + """Test datetime field accepts ISO format string.""" + model = testmodels.DatetimeFields + now = timezone.now() + obj0 = await model.create(datetime=now.isoformat()) + obj = await model.get(id=obj0.id) + assert obj.datetime == now + + +@pytest.mark.asyncio +async def test_datetime_values(db): + """Test datetime field in values() query.""" + model = testmodels.DatetimeFields + now = timezone.now() + obj0 = await model.create(datetime=now) + values = await model.get(id=obj0.id).values("datetime") + assert values["datetime"] == now + + +@pytest.mark.asyncio +async def test_datetime_values_list(db): + """Test datetime field in values_list() query.""" + model = testmodels.DatetimeFields + now = timezone.now() + obj0 = await model.create(datetime=now) + values = await model.get(id=obj0.id).values_list("datetime", flat=True) + assert values == now + + +@pytest.mark.asyncio +async def test_datetime_get_utcnow(db): + """Test getting datetime using UTC now.""" + model = testmodels.DatetimeFields + now = datetime.now(dt_timezone.utc).replace(tzinfo=get_default_timezone()) + await model.create(datetime=now) + obj = await model.get(datetime=now) + assert obj.datetime == now + + +@pytest.mark.asyncio +async def test_datetime_get_now(db): + """Test getting datetime using timezone.now().""" + model = testmodels.DatetimeFields + now = timezone.now() + await model.create(datetime=now) + obj = await model.get(datetime=now) + assert obj.datetime == 
now + + +@pytest.mark.asyncio +async def test_datetime_count(db): + """Test count queries with datetime fields.""" + model = testmodels.DatetimeFields + now = timezone.now() + obj = await model.create(datetime=now) + assert await model.filter(datetime=obj.datetime).count() == 1 + assert await model.filter(datetime_auto=obj.datetime_auto).count() == 1 + assert await model.filter(datetime_add=obj.datetime_add).count() == 1 + + +@pytest.mark.asyncio +async def test_datetime_default_timezone(db): + """Test default timezone is UTC.""" + model = testmodels.DatetimeFields + now = timezone.now() + obj = await model.create(datetime=now) + assert obj.datetime.tzinfo.zone == "UTC" + + obj_get = await model.get(pk=obj.pk) + assert obj_get.datetime.tzinfo.zone == "UTC" + assert obj_get.datetime == now + + +@pytest.mark.asyncio +async def test_datetime_set_timezone(db): + """Test setting a custom timezone via environment variable.""" + model = testmodels.DatetimeFields + old_tz = os.environ["TIMEZONE"] + tz = "Asia/Shanghai" + os.environ["TIMEZONE"] = tz + now = datetime.now(pytz.timezone(tz)) + obj = await model.create(datetime=now) + assert obj.datetime.tzinfo.zone == tz + + obj_get = await model.get(pk=obj.pk) + assert obj_get.datetime.tzinfo.zone == tz + assert obj_get.datetime == now + + os.environ["TIMEZONE"] = old_tz + + +@pytest.mark.asyncio +async def test_datetime_timezone(db): + """Test timezone handling with USE_TZ enabled.""" + model = testmodels.DatetimeFields + old_tz = os.environ["TIMEZONE"] + old_use_tz = os.environ["USE_TZ"] + tz = "Asia/Shanghai" + os.environ["TIMEZONE"] = tz + os.environ["USE_TZ"] = "True" + + now = datetime.now(pytz.timezone(tz)) + obj = await model.create(datetime=now) + assert obj.datetime.tzinfo.zone == tz + obj_get = await model.get(pk=obj.pk) + assert obj.datetime.tzinfo.zone == tz + assert obj_get.datetime == now + + os.environ["TIMEZONE"] = old_tz + os.environ["USE_TZ"] = old_use_tz + + +@pytest.mark.asyncio 
+@test.requireCapability(dialect=NotIn("sqlite", "mssql")) +async def test_datetime_filter_by_year_month_day(db): + """Test filtering datetime by year, month, and day.""" + model = testmodels.DatetimeFields + with patch.dict(os.environ, {"USE_TZ": "True"}): + obj = await model.create(datetime=datetime(2024, 1, 2)) + same_year_objs = await model.filter(datetime__year=2024) + filtered_obj = await model.filter( + datetime__year=2024, datetime__month=1, datetime__day=2 + ).first() + assert obj == filtered_obj + assert obj.id in [i.id for i in same_year_objs] + + +# ============================================================================ +# TestTimeFields (sqlite/postgres) -> test_time_* +# ============================================================================ + + +@pytest.mark.asyncio +@test.requireCapability(dialect="sqlite") +@test.requireCapability(dialect="postgres") +async def test_time_create(db): + """Test creating time fields (sqlite/postgres).""" + model = testmodels.TimeFields + now = timezone.now().timetz() + obj0 = await model.create(time=now) + obj1 = await model.get(id=obj0.id) + assert obj1.time == now + -class TestEmpty(test.TestCase): - model: type[Model] = testmodels.DatetimeFields - - async def test_empty(self): - with self.assertRaises(IntegrityError): - await self.model.create() - - -class TestDatetimeFields(TestEmpty): - async def asyncSetUp(self): - await super().asyncSetUp() - timezone._reset_timezone_cache() - - async def asyncTearDown(self): - await super().asyncTearDown() - timezone._reset_timezone_cache() - - def test_both_auto_bad(self): - with self.assertRaisesRegex( - ConfigurationError, "You can choose only 'auto_now' or 'auto_now_add'" - ): - fields.DatetimeField(auto_now=True, auto_now_add=True) - - async def test_create(self): - now = timezone.now() - obj0 = await self.model.create(datetime=now) - obj = await self.model.get(id=obj0.id) - self.assertEqual(obj.datetime, now) - self.assertEqual(obj.datetime_null, None) - 
self.assertLess(obj.datetime_auto - now, timedelta(microseconds=20000)) - self.assertLess(obj.datetime_add - now, timedelta(microseconds=20000)) - datetime_auto = obj.datetime_auto - sleep(0.012) - await obj.save() - obj2 = await self.model.get(id=obj.id) - self.assertEqual(obj2.datetime, now) - self.assertEqual(obj2.datetime_null, None) - self.assertEqual(obj2.datetime_auto, obj.datetime_auto) - self.assertNotEqual(obj2.datetime_auto, datetime_auto) - self.assertGreater(obj2.datetime_auto - now, timedelta(microseconds=10000)) - self.assertLess(obj2.datetime_auto - now, timedelta(seconds=1)) - self.assertEqual(obj2.datetime_add, obj.datetime_add) - - async def test_update(self): - obj0 = await self.model.create( - datetime=datetime(2019, 9, 1, 0, 0, 0, tzinfo=get_default_timezone()) - ) - await self.model.filter(id=obj0.id).update( - datetime=datetime(2019, 9, 1, 6, 0, 8, tzinfo=get_default_timezone()) - ) - obj = await self.model.get(id=obj0.id) - self.assertEqual(obj.datetime, datetime(2019, 9, 1, 6, 0, 8, tzinfo=get_default_timezone())) - self.assertEqual(obj.datetime_null, None) - - async def test_filter(self): - now = timezone.now() - obj = await self.model.create(datetime=now) - self.assertEqual(await self.model.filter(datetime=now).first(), obj) - self.assertEqual(await self.model.annotate(d=F("datetime")).filter(d=now).first(), obj) - - async def test_cast(self): - now = timezone.now() - obj0 = await self.model.create(datetime=now.isoformat()) - obj = await self.model.get(id=obj0.id) - self.assertEqual(obj.datetime, now) - - async def test_values(self): - now = timezone.now() - obj0 = await self.model.create(datetime=now) - values = await self.model.get(id=obj0.id).values("datetime") - self.assertEqual(values["datetime"], now) - - async def test_values_list(self): - now = timezone.now() - obj0 = await self.model.create(datetime=now) - values = await self.model.get(id=obj0.id).values_list("datetime", flat=True) - self.assertEqual(values, now) - - async def 
test_get_utcnow(self): - now = datetime.now(dt_timezone.utc).replace(tzinfo=get_default_timezone()) - await self.model.create(datetime=now) - obj = await self.model.get(datetime=now) - self.assertEqual(obj.datetime, now) - - async def test_get_now(self): - now = timezone.now() - await self.model.create(datetime=now) - obj = await self.model.get(datetime=now) - self.assertEqual(obj.datetime, now) - - async def test_count(self): - now = timezone.now() - obj = await self.model.create(datetime=now) - self.assertEqual(await self.model.filter(datetime=obj.datetime).count(), 1) - self.assertEqual(await self.model.filter(datetime_auto=obj.datetime_auto).count(), 1) - self.assertEqual(await self.model.filter(datetime_add=obj.datetime_add).count(), 1) - - async def test_default_timezone(self): - now = timezone.now() - obj = await self.model.create(datetime=now) - self.assertEqual(obj.datetime.tzinfo.zone, "UTC") - - obj_get = await self.model.get(pk=obj.pk) - self.assertEqual(obj_get.datetime.tzinfo.zone, "UTC") - self.assertEqual(obj_get.datetime, now) - - async def test_set_timezone(self): - old_tz = os.environ["TIMEZONE"] - tz = "Asia/Shanghai" - os.environ["TIMEZONE"] = tz - now = datetime.now(pytz.timezone(tz)) - obj = await self.model.create(datetime=now) - self.assertEqual(obj.datetime.tzinfo.zone, tz) - - obj_get = await self.model.get(pk=obj.pk) - self.assertEqual(obj_get.datetime.tzinfo.zone, tz) - self.assertEqual(obj_get.datetime, now) - - os.environ["TIMEZONE"] = old_tz - - async def test_timezone(self): - old_tz = os.environ["TIMEZONE"] - old_use_tz = os.environ["USE_TZ"] - tz = "Asia/Shanghai" - os.environ["TIMEZONE"] = tz - os.environ["USE_TZ"] = "True" - - now = datetime.now(pytz.timezone(tz)) - obj = await self.model.create(datetime=now) - self.assertEqual(obj.datetime.tzinfo.zone, tz) - obj_get = await self.model.get(pk=obj.pk) - self.assertEqual(obj.datetime.tzinfo.zone, tz) - self.assertEqual(obj_get.datetime, now) - - os.environ["TIMEZONE"] = old_tz - 
os.environ["USE_TZ"] = old_use_tz - - @test.requireCapability(dialect=NotIn("sqlite", "mssql")) - async def test_filter_by_year_month_day(self): - with patch.dict(os.environ, {"USE_TZ": "True"}): - obj = await self.model.create(datetime=datetime(2024, 1, 2)) - same_year_objs = await self.model.filter(datetime__year=2024) - filtered_obj = await self.model.filter( - datetime__year=2024, datetime__month=1, datetime__day=2 - ).first() - assert obj == filtered_obj - assert obj.id in [i.id for i in same_year_objs] - - -class TestTime(test.TestCase): +@pytest.mark.asyncio +@test.requireCapability(dialect="sqlite") +@test.requireCapability(dialect="postgres") +async def test_time_cast(db): + """Test time field accepts ISO format string (sqlite/postgres).""" model = testmodels.TimeFields + obj0 = await model.create(time="21:00+00:00") + obj1 = await model.get(id=obj0.id) + assert obj1.time == time.fromisoformat("21:00+00:00") +@pytest.mark.asyncio @test.requireCapability(dialect="sqlite") @test.requireCapability(dialect="postgres") -class TestTimeFields(TestTime): - async def test_create(self): - now = timezone.now().timetz() - obj0 = await self.model.create(time=now) - boj1 = await self.model.get(id=obj0.id) - self.assertEqual(boj1.time, now) - - async def test_cast(self): - obj0 = await self.model.create(time="21:00+00:00") - obj1 = await self.model.get(id=obj0.id) - self.assertEqual(obj1.time, time.fromisoformat("21:00+00:00")) - - async def test_values(self): - now = timezone.now().timetz() - obj0 = await self.model.create(time=now) - values = await self.model.get(id=obj0.id).values("time") - self.assertEqual(values["time"], now) - - async def test_values_list(self): - now = timezone.now().timetz() - obj0 = await self.model.create(time=now) - values = await self.model.get(id=obj0.id).values_list("time", flat=True) - self.assertEqual(values, now) - - async def test_get(self): - now = timezone.now().timetz() - await self.model.create(time=now) - obj = await 
self.model.get(time=now) - self.assertEqual(obj.time, now) +async def test_time_values(db): + """Test time field in values() query (sqlite/postgres).""" + model = testmodels.TimeFields + now = timezone.now().timetz() + obj0 = await model.create(time=now) + values = await model.get(id=obj0.id).values("time") + assert values["time"] == now +@pytest.mark.asyncio +@test.requireCapability(dialect="sqlite") +@test.requireCapability(dialect="postgres") +async def test_time_values_list(db): + """Test time field in values_list() query (sqlite/postgres).""" + model = testmodels.TimeFields + now = timezone.now().timetz() + obj0 = await model.create(time=now) + values = await model.get(id=obj0.id).values_list("time", flat=True) + assert values == now + + +@pytest.mark.asyncio +@test.requireCapability(dialect="sqlite") +@test.requireCapability(dialect="postgres") +async def test_time_get(db): + """Test getting by time field (sqlite/postgres).""" + model = testmodels.TimeFields + now = timezone.now().timetz() + await model.create(time=now) + obj = await model.get(time=now) + assert obj.time == now + + +# ============================================================================ +# TestTimeFieldsMySQL -> test_time_mysql_* +# ============================================================================ + + +@pytest.mark.asyncio +@test.requireCapability(dialect="mysql") +async def test_time_mysql_create(db): + """Test creating time fields (mysql returns timedelta).""" + model = testmodels.TimeFields + now = timezone.now().timetz() + obj0 = await model.create(time=now) + obj1 = await model.get(id=obj0.id) + assert obj1.time == timedelta( + hours=now.hour, + minutes=now.minute, + seconds=now.second, + microseconds=now.microsecond, + ) + + +@pytest.mark.asyncio +@test.requireCapability(dialect="mysql") +async def test_time_mysql_cast(db): + """Test time field accepts ISO format string (mysql returns timedelta).""" + model = testmodels.TimeFields + obj0 = await 
model.create(time="21:00+00:00") + obj1 = await model.get(id=obj0.id) + t = time.fromisoformat("21:00+00:00") + assert obj1.time == timedelta( + hours=t.hour, + minutes=t.minute, + seconds=t.second, + microseconds=t.microsecond, + ) + + +@pytest.mark.asyncio +@test.requireCapability(dialect="mysql") +async def test_time_mysql_values(db): + """Test time field in values() query (mysql returns timedelta).""" + model = testmodels.TimeFields + now = timezone.now().timetz() + obj0 = await model.create(time=now) + values = await model.get(id=obj0.id).values("time") + assert values["time"] == timedelta( + hours=now.hour, + minutes=now.minute, + seconds=now.second, + microseconds=now.microsecond, + ) + + +@pytest.mark.asyncio @test.requireCapability(dialect="mysql") -class TestTimeFieldsMySQL(TestTime): - async def test_create(self): - now = timezone.now().timetz() - obj0 = await self.model.create(time=now) - boj1 = await self.model.get(id=obj0.id) - self.assertEqual( - boj1.time, - timedelta( - hours=now.hour, - minutes=now.minute, - seconds=now.second, - microseconds=now.microsecond, - ), - ) - - async def test_cast(self): - obj0 = await self.model.create(time="21:00+00:00") - obj1 = await self.model.get(id=obj0.id) - t = time.fromisoformat("21:00+00:00") - self.assertEqual( - obj1.time, - timedelta( - hours=t.hour, - minutes=t.minute, - seconds=t.second, - microseconds=t.microsecond, - ), - ) - - async def test_values(self): - now = timezone.now().timetz() - obj0 = await self.model.create(time=now) - values = await self.model.get(id=obj0.id).values("time") - self.assertEqual( - values["time"], - timedelta( - hours=now.hour, - minutes=now.minute, - seconds=now.second, - microseconds=now.microsecond, - ), - ) - - async def test_values_list(self): - now = timezone.now().timetz() - obj0 = await self.model.create(time=now) - values = await self.model.get(id=obj0.id).values_list("time", flat=True) - self.assertEqual( - values, - timedelta( - hours=now.hour, - 
minutes=now.minute, - seconds=now.second, - microseconds=now.microsecond, - ), - ) - - async def test_get(self): - now = timezone.now().timetz() - await self.model.create(time=now) - obj = await self.model.get(time=now) - self.assertEqual( - obj.time, - timedelta( - hours=now.hour, - minutes=now.minute, - seconds=now.second, - microseconds=now.microsecond, - ), - ) - - -class TestDateFields(TestEmpty): +async def test_time_mysql_values_list(db): + """Test time field in values_list() query (mysql returns timedelta).""" + model = testmodels.TimeFields + now = timezone.now().timetz() + obj0 = await model.create(time=now) + values = await model.get(id=obj0.id).values_list("time", flat=True) + assert values == timedelta( + hours=now.hour, + minutes=now.minute, + seconds=now.second, + microseconds=now.microsecond, + ) + + +@pytest.mark.asyncio +@test.requireCapability(dialect="mysql") +async def test_time_mysql_get(db): + """Test getting by time field (mysql returns timedelta).""" + model = testmodels.TimeFields + now = timezone.now().timetz() + await model.create(time=now) + obj = await model.get(time=now) + assert obj.time == timedelta( + hours=now.hour, + minutes=now.minute, + seconds=now.second, + microseconds=now.microsecond, + ) + + +# ============================================================================ +# TestDateFields -> test_date_* +# ============================================================================ + + +@pytest.mark.asyncio +async def test_empty_date_fields(db): + """Test that creating DateFields without required field raises IntegrityError.""" + with pytest.raises(IntegrityError): + await testmodels.DateFields.create() + + +@pytest.mark.asyncio +async def test_date_create(db): + """Test creating date fields.""" + model = testmodels.DateFields + today = date.today() + obj0 = await model.create(date=today) + obj = await model.get(id=obj0.id) + assert obj.date == today + assert obj.date_null is None + await obj.save() + obj2 = await 
model.get(id=obj.id) + assert obj == obj2 + + +@pytest.mark.asyncio +async def test_date_cast(db): + """Test date field accepts ISO format string.""" + model = testmodels.DateFields + today = date.today() + obj0 = await model.create(date=today.isoformat()) + obj = await model.get(id=obj0.id) + assert obj.date == today + + +@pytest.mark.asyncio +async def test_date_values(db): + """Test date field in values() query.""" + model = testmodels.DateFields + today = date.today() + obj0 = await model.create(date=today) + values = await model.get(id=obj0.id).values("date") + assert values["date"] == today + + +@pytest.mark.asyncio +async def test_date_values_list(db): + """Test date field in values_list() query.""" + model = testmodels.DateFields + today = date.today() + obj0 = await model.create(date=today) + values = await model.get(id=obj0.id).values_list("date", flat=True) + assert values == today + + +@pytest.mark.asyncio +async def test_date_get(db): + """Test getting by date field.""" model = testmodels.DateFields + today = date.today() + await model.create(date=today) + obj = await model.get(date=today) + assert obj.date == today - async def test_create(self): - today = date.today() - obj0 = await self.model.create(date=today) - obj = await self.model.get(id=obj0.id) - self.assertEqual(obj.date, today) - self.assertEqual(obj.date_null, None) - await obj.save() - obj2 = await self.model.get(id=obj.id) - self.assertEqual(obj, obj2) - - async def test_cast(self): - today = date.today() - obj0 = await self.model.create(date=today.isoformat()) - obj = await self.model.get(id=obj0.id) - self.assertEqual(obj.date, today) - - async def test_values(self): - today = date.today() - obj0 = await self.model.create(date=today) - values = await self.model.get(id=obj0.id).values("date") - self.assertEqual(values["date"], today) - - async def test_values_list(self): - today = date.today() - obj0 = await self.model.create(date=today) - values = await 
self.model.get(id=obj0.id).values_list("date", flat=True) - self.assertEqual(values, today) - - async def test_get(self): - today = date.today() - await self.model.create(date=today) - obj = await self.model.get(date=today) - self.assertEqual(obj.date, today) - - async def test_date_str(self): - obj0 = await self.model.create(date="2020-08-17") - obj1 = await self.model.get(date="2020-08-17") - self.assertEqual(obj0.date, obj1.date) - with self.assertRaises((ParseError, ValueError)): - await self.model.create(date="2020-08-xx") - await self.model.filter(date="2020-08-17").update(date="2020-08-18") - obj2 = await self.model.get(date="2020-08-18") - self.assertEqual(obj2.date, date(year=2020, month=8, day=18)) - - -class TestTimeDeltaFields(TestEmpty): + +@pytest.mark.asyncio +async def test_date_str(db): + """Test date field with string input and filtering/updating.""" + model = testmodels.DateFields + obj0 = await model.create(date="2020-08-17") + obj1 = await model.get(date="2020-08-17") + assert obj0.date == obj1.date + with pytest.raises((ParseError, ValueError)): + await model.create(date="2020-08-xx") + await model.filter(date="2020-08-17").update(date="2020-08-18") + obj2 = await model.get(date="2020-08-18") + assert obj2.date == date(year=2020, month=8, day=18) + + +# ============================================================================ +# TestTimeDeltaFields -> test_timedelta_* +# ============================================================================ + + +@pytest.mark.asyncio +async def test_empty_timedelta_fields(db): + """Test that creating TimeDeltaFields without required field raises IntegrityError.""" + with pytest.raises(IntegrityError): + await testmodels.TimeDeltaFields.create() + + +@pytest.mark.asyncio +async def test_timedelta_create(db): + """Test creating timedelta fields.""" + model = testmodels.TimeDeltaFields + obj0 = await model.create(timedelta=timedelta(days=35, seconds=8, microseconds=1)) + obj = await model.get(id=obj0.id) 
+ assert obj.timedelta == timedelta(days=35, seconds=8, microseconds=1) + assert obj.timedelta_null is None + await obj.save() + obj2 = await model.get(id=obj.id) + assert obj == obj2 + + +@pytest.mark.asyncio +async def test_timedelta_values(db): + """Test timedelta field in values() query.""" model = testmodels.TimeDeltaFields + obj0 = await model.create(timedelta=timedelta(days=35, seconds=8, microseconds=1)) + values = await model.get(id=obj0.id).values("timedelta") + assert values["timedelta"] == timedelta(days=35, seconds=8, microseconds=1) - async def test_create(self): - obj0 = await self.model.create(timedelta=timedelta(days=35, seconds=8, microseconds=1)) - obj = await self.model.get(id=obj0.id) - self.assertEqual(obj.timedelta, timedelta(days=35, seconds=8, microseconds=1)) - self.assertEqual(obj.timedelta_null, None) - await obj.save() - obj2 = await self.model.get(id=obj.id) - self.assertEqual(obj, obj2) - - async def test_values(self): - obj0 = await self.model.create(timedelta=timedelta(days=35, seconds=8, microseconds=1)) - values = await self.model.get(id=obj0.id).values("timedelta") - self.assertEqual(values["timedelta"], timedelta(days=35, seconds=8, microseconds=1)) - - async def test_values_list(self): - obj0 = await self.model.create(timedelta=timedelta(days=35, seconds=8, microseconds=1)) - values = await self.model.get(id=obj0.id).values_list("timedelta", flat=True) - self.assertEqual(values, timedelta(days=35, seconds=8, microseconds=1)) - - async def test_get(self): - delta = timedelta(days=35, seconds=8, microseconds=2) - await self.model.create(timedelta=delta) - obj = await self.model.get(timedelta=delta) - self.assertEqual(obj.timedelta, delta) + +@pytest.mark.asyncio +async def test_timedelta_values_list(db): + """Test timedelta field in values_list() query.""" + model = testmodels.TimeDeltaFields + obj0 = await model.create(timedelta=timedelta(days=35, seconds=8, microseconds=1)) + values = await 
model.get(id=obj0.id).values_list("timedelta", flat=True) + assert values == timedelta(days=35, seconds=8, microseconds=1) + + +@pytest.mark.asyncio +async def test_timedelta_get(db): + """Test getting by timedelta field.""" + model = testmodels.TimeDeltaFields + delta = timedelta(days=35, seconds=8, microseconds=2) + await model.create(timedelta=delta) + obj = await model.get(timedelta=delta) + assert obj.timedelta == delta diff --git a/tests/fields/test_uuid.py b/tests/fields/test_uuid.py index 7f03b8ca9..3901a1125 100644 --- a/tests/fields/test_uuid.py +++ b/tests/fields/test_uuid.py @@ -1,46 +1,53 @@ import uuid +import pytest + from tests import testmodels -from tortoise.contrib import test from tortoise.exceptions import IntegrityError -class TestUUIDFields(test.TestCase): - async def test_empty(self): - with self.assertRaises(IntegrityError): - await testmodels.UUIDFields.create() - - async def test_create(self): - data = uuid.uuid4() - obj0 = await testmodels.UUIDFields.create(data=data) - self.assertIsInstance(obj0.data, uuid.UUID) - self.assertIsInstance(obj0.data_auto, uuid.UUID) - self.assertEqual(obj0.data_null, None) - obj = await testmodels.UUIDFields.get(id=obj0.id) - self.assertIsInstance(obj.data, uuid.UUID) - self.assertIsInstance(obj.data_auto, uuid.UUID) - self.assertEqual(obj.data, data) - self.assertEqual(obj.data_null, None) - await obj.save() - obj2 = await testmodels.UUIDFields.get(id=obj.id) - self.assertEqual(obj, obj2) - - await obj.delete() - obj = await testmodels.UUIDFields.filter(id=obj0.id).first() - self.assertEqual(obj, None) - - async def test_update(self): - data = uuid.uuid4() - data2 = uuid.uuid4() - obj0 = await testmodels.UUIDFields.create(data=data) - await testmodels.UUIDFields.filter(id=obj0.id).update(data=data2) - obj = await testmodels.UUIDFields.get(id=obj0.id) - self.assertEqual(obj.data, data2) - self.assertEqual(obj.data_null, None) - - async def test_create_not_null(self): - data = uuid.uuid4() - obj0 = await 
testmodels.UUIDFields.create(data=data, data_null=data) - obj = await testmodels.UUIDFields.get(id=obj0.id) - self.assertEqual(obj.data, data) - self.assertEqual(obj.data_null, data) +@pytest.mark.asyncio +async def test_empty(db): + with pytest.raises(IntegrityError): + await testmodels.UUIDFields.create() + + +@pytest.mark.asyncio +async def test_create(db): + data = uuid.uuid4() + obj0 = await testmodels.UUIDFields.create(data=data) + assert isinstance(obj0.data, uuid.UUID) + assert isinstance(obj0.data_auto, uuid.UUID) + assert obj0.data_null is None + obj = await testmodels.UUIDFields.get(id=obj0.id) + assert isinstance(obj.data, uuid.UUID) + assert isinstance(obj.data_auto, uuid.UUID) + assert obj.data == data + assert obj.data_null is None + await obj.save() + obj2 = await testmodels.UUIDFields.get(id=obj.id) + assert obj == obj2 + + await obj.delete() + obj = await testmodels.UUIDFields.filter(id=obj0.id).first() + assert obj is None + + +@pytest.mark.asyncio +async def test_update(db): + data = uuid.uuid4() + data2 = uuid.uuid4() + obj0 = await testmodels.UUIDFields.create(data=data) + await testmodels.UUIDFields.filter(id=obj0.id).update(data=data2) + obj = await testmodels.UUIDFields.get(id=obj0.id) + assert obj.data == data2 + assert obj.data_null is None + + +@pytest.mark.asyncio +async def test_create_not_null(db): + data = uuid.uuid4() + obj0 = await testmodels.UUIDFields.create(data=data, data_null=data) + obj = await testmodels.UUIDFields.get(id=obj0.id) + assert obj.data == data + assert obj.data_null == data diff --git a/tests/migrations/test_runtime_migrations.py b/tests/migrations/test_runtime_migrations.py index 8ba264d0b..3a3d39bde 100644 --- a/tests/migrations/test_runtime_migrations.py +++ b/tests/migrations/test_runtime_migrations.py @@ -6,8 +6,8 @@ import pytest -from tortoise import connections from tortoise.backends.base.client import BaseDBAsyncClient, Capabilities +from tortoise.context import TortoiseContext from 
tortoise.migrations.executor import MigrationExecutor, MigrationTarget from tortoise.migrations.graph import MigrationGraph, MigrationKey from tortoise.migrations.loader import MigrationLoader @@ -428,19 +428,15 @@ async def test_runpython_historical_models_survive_schema_change( module_path = _write_runpython_migrations(tmp_path, "blog") monkeypatch.syspath_prepend(str(tmp_path)) - old_config = connections._db_config - old_create_db = connections._create_db - connections._clear_storage() - connections._init_config( - { - "default": { - "engine": "tortoise.backends.sqlite", - "credentials": {"file_path": str(tmp_path / "runpython.sqlite3")}, + async with TortoiseContext() as ctx: + ctx.connections._init_config( + { + "default": { + "engine": "tortoise.backends.sqlite", + "credentials": {"file_path": str(tmp_path / "runpython.sqlite3")}, + } } - } - ) - connection = None - try: + ) apps_config = { "blog": { "models": [], @@ -448,7 +444,7 @@ async def test_runpython_historical_models_survive_schema_change( "migrations": module_path, } } - connection = connections.get("default") + connection = ctx.connections.get("default") executor = MigrationExecutor(connection, apps_config) await executor.migrate() @@ -460,9 +456,5 @@ async def test_runpython_historical_models_survive_schema_change( await executor.migrate([MigrationTarget(app_label="blog", name="__latest__")]) assert module.CALLS == ["forward", "reverse", "forward"] - finally: - if connection is not None: - await connection.close() - connections._clear_storage() - connections._db_config = old_config - connections._create_db = old_create_db + + await connection.close() diff --git a/tests/migrations/test_state_generator.py b/tests/migrations/test_state_generator.py index 38bf18321..f2b662aaa 100644 --- a/tests/migrations/test_state_generator.py +++ b/tests/migrations/test_state_generator.py @@ -2,7 +2,8 @@ from typing import Any, cast -from tortoise import connections, fields +from tortoise import fields +from 
tortoise.context import TortoiseContext from tortoise.fields.relational import ForeignKeyFieldInstance from tortoise.migrations.operations import CreateModel from tortoise.migrations.schema_generator.state import ModelState, State @@ -42,18 +43,15 @@ def test_field_signature_ignores_implicit_db_column() -> None: def test_state_apps_builds_relations_before_querysets() -> None: - old_config = connections._db_config - old_create_db = connections._create_db - connections._clear_storage() - connections._init_config( - { - "default": { - "engine": "tortoise.backends.sqlite", - "credentials": {"file_path": ":memory:"}, + with TortoiseContext() as ctx: + ctx.connections._init_config( + { + "default": { + "engine": "tortoise.backends.sqlite", + "credentials": {"file_path": ":memory:"}, + } } - } - ) - try: + ) state = State(models={}, apps=StateApps(default_connections={"blog": "default"})) CreateModel( name="Author", @@ -70,7 +68,3 @@ def test_state_apps_builds_relations_before_querysets() -> None: post_model = state.apps.get_model("blog.Post") author_field = cast(ForeignKeyFieldInstance, post_model._meta.fields_map["author"]) assert author_field.to_field_instance is not None - finally: - connections._clear_storage() - connections._db_config = old_config - connections._create_db = old_create_db diff --git a/tests/model_setup/test__models__.py b/tests/model_setup/test__models__.py index c5427c219..a9e483ac8 100644 --- a/tests/model_setup/test__models__.py +++ b/tests/model_setup/test__models__.py @@ -2,60 +2,80 @@ Tests for __models__ """ +import os import re from unittest.mock import AsyncMock, patch +import pytest + from tortoise import Tortoise, connections -from tortoise.contrib import test +from tortoise.backends.base.config_generator import generate_config from tortoise.exceptions import ConfigurationError from tortoise.utils import get_schema_sql -class TestGenerateSchema(test.SimpleTestCase): - async def asyncSetUp(self): - await super().asyncSetUp() - try: - 
Tortoise.apps = None - Tortoise._inited = False - except ConfigurationError: - pass +async def _reset_tortoise(): + """Helper to reset Tortoise state before each test.""" + try: + Tortoise.apps = None Tortoise._inited = False - self.sqls = "" - self.post_sqls = "" - self.engine = test.getDBConfig(app_label="models", modules=[])["connections"]["models"][ - "engine" - ] - - async def init_for(self, module: str, safe=False) -> None: - if self.engine != "tortoise.backends.sqlite": - raise test.SkipTest("sqlite only") - with patch( - "tortoise.backends.sqlite.client.SqliteClient.create_connection", new=AsyncMock() - ): - await Tortoise.init( - { - "connections": { - "default": { - "engine": "tortoise.backends.sqlite", - "credentials": {"file_path": ":memory:"}, - } - }, - "apps": {"models": {"models": [module], "default_connection": "default"}}, - } - ) - self.sqls = get_schema_sql(connections.get("default"), safe).split(";\n") - - def get_sql(self, text: str) -> str: - return str(re.sub(r"[ \t\n\r]+", " ", [sql for sql in self.sqls if text in sql][0])) - - async def test_good(self): - await self.init_for("tests.model_setup.models__models__good") - self.assertIn("goodtournament", "; ".join(self.sqls)) - self.assertIn("inaclasstournament", "; ".join(self.sqls)) - self.assertNotIn("badtournament", "; ".join(self.sqls)) - - async def test_bad(self): - await self.init_for("tests.model_setup.models__models__bad") - self.assertNotIn("goodtournament", "; ".join(self.sqls)) - self.assertNotIn("inaclasstournament", "; ".join(self.sqls)) - self.assertIn("badtournament", "; ".join(self.sqls)) + except ConfigurationError: + pass + Tortoise._inited = False + + +def _get_engine() -> str: + """Get the current test engine.""" + db_url = os.getenv("TORTOISE_TEST_DB", "sqlite://:memory:") + config = generate_config(db_url, app_modules={"models": []}, connection_label="models") + return config["connections"]["models"]["engine"] + + +async def _init_for(module: str, safe: bool = False) -> 
list[str]: + """ + Initialize Tortoise for a specific module and return SQL statements. + + Raises SkipTest if not using sqlite. + """ + engine = _get_engine() + if engine != "tortoise.backends.sqlite": + pytest.skip("sqlite only") + + with patch("tortoise.backends.sqlite.client.SqliteClient.create_connection", new=AsyncMock()): + await Tortoise.init( + { + "connections": { + "default": { + "engine": "tortoise.backends.sqlite", + "credentials": {"file_path": ":memory:"}, + } + }, + "apps": {"models": {"models": [module], "default_connection": "default"}}, + } + ) + return get_schema_sql(connections.get("default"), safe).split(";\n") + + +def _get_sql(sqls: list[str], text: str) -> str: + """Get SQL statement containing the given text.""" + return str(re.sub(r"[ \t\n\r]+", " ", [sql for sql in sqls if text in sql][0])) + + +@pytest.mark.asyncio +async def test_good(): + await _reset_tortoise() + sqls = await _init_for("tests.model_setup.models__models__good") + sql_joined = "; ".join(sqls) + assert "goodtournament" in sql_joined + assert "inaclasstournament" in sql_joined + assert "badtournament" not in sql_joined + + +@pytest.mark.asyncio +async def test_bad(): + await _reset_tortoise() + sqls = await _init_for("tests.model_setup.models__models__bad") + sql_joined = "; ".join(sqls) + assert "goodtournament" not in sql_joined + assert "inaclasstournament" not in sql_joined + assert "badtournament" in sql_joined diff --git a/tests/model_setup/test_bad_relation_reference.py b/tests/model_setup/test_bad_relation_reference.py index d403e43f0..55c1f18aa 100644 --- a/tests/model_setup/test_bad_relation_reference.py +++ b/tests/model_setup/test_bad_relation_reference.py @@ -1,24 +1,55 @@ +import pytest + from tortoise import Tortoise -from tortoise.contrib import test +from tortoise.context import TortoiseContext, get_current_context from tortoise.exceptions import ConfigurationError +# Save original classproperties before any test can shadow them +_original_apps_prop = 
Tortoise.__dict__["apps"] +_original_inited_prop = Tortoise.__dict__["_inited"] + + +async def _reset_tortoise(): + """Helper to reset Tortoise state before each test. + + Note: We MUST NOT set Tortoise.apps = None or Tortoise._inited = False + because these are classproperties and setting them shadows the property + with a class attribute, breaking future access. + """ + # Restore original classproperties if they were shadowed + if not isinstance(Tortoise.__dict__.get("apps"), type(_original_apps_prop)): + type.__setattr__(Tortoise, "apps", _original_apps_prop) + if not isinstance(Tortoise.__dict__.get("_inited"), type(_original_inited_prop)): + type.__setattr__(Tortoise, "_inited", _original_inited_prop) -class TestBadRelationReferenceErrors(test.SimpleTestCase): - async def asyncSetUp(self): - await super().asyncSetUp() - try: - Tortoise.apps = None - Tortoise._inited = False - except ConfigurationError: - pass - Tortoise._inited = False + # Get the current context and properly reset it + ctx = get_current_context() + if ctx is not None: + # Clear db_config first to prevent close_all from trying to import bad backends + if ctx._connections is not None: + # Clear storage without closing (to avoid importing bad backends) + ctx._connections._storage.clear() + ctx._connections._db_config = None + ctx._connections = None + ctx._apps = None + ctx._inited = False + ctx._default_connection = None + else: + # No context exists - create one for the test + ctx = TortoiseContext() + ctx.__enter__() - async def asyncTearDown(self) -> None: - await Tortoise._reset_apps() - await super().asyncTearDown() - async def test_wrong_app_init(self): - with self.assertRaisesRegex(ConfigurationError, "No app with name 'app' registered."): +async def _teardown_tortoise(): + """Helper to teardown Tortoise state after each test.""" + await Tortoise._reset_apps() + + +@pytest.mark.asyncio +async def test_wrong_app_init(): + await _reset_tortoise() + try: + with 
pytest.raises(ConfigurationError, match="No app with name 'app' registered."): await Tortoise.init( { "connections": { @@ -35,10 +66,16 @@ async def test_wrong_app_init(self): }, } ) + finally: + await _teardown_tortoise() + - async def test_wrong_model_init(self): - with self.assertRaisesRegex( - ConfigurationError, "No model with name 'Tour' registered in app 'models'." +@pytest.mark.asyncio +async def test_wrong_model_init(): + await _reset_tortoise() + try: + with pytest.raises( + ConfigurationError, match="No model with name 'Tour' registered in app 'models'." ): await Tortoise.init( { @@ -56,10 +93,16 @@ async def test_wrong_model_init(self): }, } ) + finally: + await _teardown_tortoise() - async def test_no_app_in_reference_init(self): - with self.assertRaisesRegex( - ConfigurationError, 'ForeignKeyField accepts model name in format "app.Model"' + +@pytest.mark.asyncio +async def test_no_app_in_reference_init(): + await _reset_tortoise() + try: + with pytest.raises( + ConfigurationError, match='ForeignKeyField accepts model name in format "app.Model"' ): await Tortoise.init( { @@ -77,10 +120,16 @@ async def test_no_app_in_reference_init(self): }, } ) + finally: + await _teardown_tortoise() + - async def test_more_than_two_dots_in_reference_init(self): - with self.assertRaisesRegex( - ConfigurationError, 'ForeignKeyField accepts model name in format "app.Model"' +@pytest.mark.asyncio +async def test_more_than_two_dots_in_reference_init(): + await _reset_tortoise() + try: + with pytest.raises( + ConfigurationError, match='ForeignKeyField accepts model name in format "app.Model"' ): await Tortoise.init( { @@ -98,10 +147,16 @@ async def test_more_than_two_dots_in_reference_init(self): }, } ) + finally: + await _teardown_tortoise() - async def test_no_app_in_o2o_reference_init(self): - with self.assertRaisesRegex( - ConfigurationError, 'OneToOneField accepts model name in format "app.Model"' + +@pytest.mark.asyncio +async def test_no_app_in_o2o_reference_init(): 
+ await _reset_tortoise() + try: + with pytest.raises( + ConfigurationError, match='OneToOneField accepts model name in format "app.Model"' ): await Tortoise.init( { @@ -119,10 +174,16 @@ async def test_no_app_in_o2o_reference_init(self): }, } ) + finally: + await _teardown_tortoise() + - async def test_non_unique_field_in_fk_reference_init(self): - with self.assertRaisesRegex( - ConfigurationError, 'field "uuid" in model "Tournament" is not unique' +@pytest.mark.asyncio +async def test_non_unique_field_in_fk_reference_init(): + await _reset_tortoise() + try: + with pytest.raises( + ConfigurationError, match='field "uuid" in model "Tournament" is not unique' ): await Tortoise.init( { @@ -140,10 +201,16 @@ async def test_non_unique_field_in_fk_reference_init(self): }, } ) + finally: + await _teardown_tortoise() - async def test_non_exist_field_in_fk_reference_init(self): - with self.assertRaisesRegex( - ConfigurationError, 'there is no field named "uuids" in model "Tournament"' + +@pytest.mark.asyncio +async def test_non_exist_field_in_fk_reference_init(): + await _reset_tortoise() + try: + with pytest.raises( + ConfigurationError, match='there is no field named "uuids" in model "Tournament"' ): await Tortoise.init( { @@ -161,10 +228,16 @@ async def test_non_exist_field_in_fk_reference_init(self): }, } ) + finally: + await _teardown_tortoise() + - async def test_non_unique_field_in_o2o_reference_init(self): - with self.assertRaisesRegex( - ConfigurationError, 'field "uuid" in model "Tournament" is not unique' +@pytest.mark.asyncio +async def test_non_unique_field_in_o2o_reference_init(): + await _reset_tortoise() + try: + with pytest.raises( + ConfigurationError, match='field "uuid" in model "Tournament" is not unique' ): await Tortoise.init( { @@ -182,10 +255,16 @@ async def test_non_unique_field_in_o2o_reference_init(self): }, } ) + finally: + await _teardown_tortoise() + - async def test_non_exist_field_in_o2o_reference_init(self): - with self.assertRaisesRegex( 
- ConfigurationError, 'there is no field named "uuids" in model "Tournament"' +@pytest.mark.asyncio +async def test_non_exist_field_in_o2o_reference_init(): + await _reset_tortoise() + try: + with pytest.raises( + ConfigurationError, match='there is no field named "uuids" in model "Tournament"' ): await Tortoise.init( { @@ -203,3 +282,5 @@ async def test_non_exist_field_in_o2o_reference_init(self): }, } ) + finally: + await _teardown_tortoise() diff --git a/tests/model_setup/test_init.py b/tests/model_setup/test_init.py index 8549eca65..b2f619b30 100644 --- a/tests/model_setup/test_init.py +++ b/tests/model_setup/test_init.py @@ -1,23 +1,94 @@ import os from unittest.mock import patch +import pytest + from tortoise import Tortoise, connections from tortoise.config import AppConfig, ConnectionConfig, TortoiseConfig -from tortoise.contrib import test +from tortoise.context import TortoiseContext, get_current_context from tortoise.exceptions import ConfigurationError +# Save original classproperties before any test can shadow them +_original_apps_prop = Tortoise.__dict__["apps"] +_original_inited_prop = Tortoise.__dict__["_inited"] + + +async def _reset_tortoise(): + """Helper to reset Tortoise state before each test. + + Note: We MUST NOT set Tortoise.apps = None or Tortoise._inited = False + because these are classproperties and setting them shadows the property + with a class attribute, breaking future access. 
+ """ + # Restore original classproperties if they were shadowed + if not isinstance(Tortoise.__dict__.get("apps"), type(_original_apps_prop)): + type.__setattr__(Tortoise, "apps", _original_apps_prop) + if not isinstance(Tortoise.__dict__.get("_inited"), type(_original_inited_prop)): + type.__setattr__(Tortoise, "_inited", _original_inited_prop) -class TestInitErrors(test.SimpleTestCase): - async def asyncSetUp(self): - await super().asyncSetUp() - try: - Tortoise.apps = None - Tortoise._inited = False - except ConfigurationError: - pass - Tortoise._inited = False + # Get the current context and properly reset it + ctx = get_current_context() + if ctx is not None: + # Clear db_config first to prevent close_all from trying to import bad backends + if ctx._connections is not None: + # Clear storage without closing (to avoid importing bad backends) + ctx._connections._storage.clear() + ctx._connections._db_config = None + ctx._connections = None + ctx._apps = None + ctx._inited = False + ctx._default_connection = None + else: + # No context exists - create one for the test + ctx = TortoiseContext() + ctx.__enter__() + + +@pytest.mark.asyncio +async def test_basic_init(): + await _reset_tortoise() + await Tortoise.init( + { + "connections": { + "default": { + "engine": "tortoise.backends.sqlite", + "credentials": {"file_path": ":memory:"}, + } + }, + "apps": {"models": {"models": ["tests.testmodels"], "default_connection": "default"}}, + } + ) + assert "models" in Tortoise.apps + assert connections.get("default") is not None - async def test_basic_init(self): + +@pytest.mark.asyncio +async def test_dataclass_init(): + await _reset_tortoise() + await Tortoise.init( + config=TortoiseConfig( + connections={ + "default": ConnectionConfig( + engine="tortoise.backends.sqlite", + credentials={"file_path": ":memory:"}, + ) + }, + apps={ + "models": AppConfig( + models=["tests.testmodels"], + default_connection="default", + ) + }, + ) + ) + assert "models" in Tortoise.apps + 
assert connections.get("default") is not None + + +@pytest.mark.asyncio +async def test_empty_modules_init(): + await _reset_tortoise() + with pytest.warns(RuntimeWarning, match='Module "tests.model_setup" has no models'): await Tortoise.init( { "connections": { @@ -27,392 +98,416 @@ async def test_basic_init(self): } }, "apps": { - "models": {"models": ["tests.testmodels"], "default_connection": "default"} + "models": {"models": ["tests.model_setup"], "default_connection": "default"} }, } ) - self.assertIn("models", Tortoise.apps) - self.assertIsNotNone(connections.get("default")) - async def test_dataclass_init(self): + +@pytest.mark.asyncio +async def test_dup1_init(): + await _reset_tortoise() + with pytest.raises( + ConfigurationError, match='backward relation "events" duplicates in model Tournament' + ): await Tortoise.init( - config=TortoiseConfig( - connections={ - "default": ConnectionConfig( - engine="tortoise.backends.sqlite", - credentials={"file_path": ":memory:"}, - ) + { + "connections": { + "default": { + "engine": "tortoise.backends.sqlite", + "credentials": {"file_path": ":memory:"}, + } }, - apps={ - "models": AppConfig( - models=["tests.testmodels"], - default_connection="default", - ) + "apps": { + "models": { + "models": ["tests.model_setup.models_dup1"], + "default_connection": "default", + } }, - ) + } ) - self.assertIn("models", Tortoise.apps) - self.assertIsNotNone(connections.get("default")) - - async def test_empty_modules_init(self): - with self.assertWarnsRegex(RuntimeWarning, 'Module "tests.model_setup" has no models'): - await Tortoise.init( - { - "connections": { - "default": { - "engine": "tortoise.backends.sqlite", - "credentials": {"file_path": ":memory:"}, - } - }, - "apps": { - "models": {"models": ["tests.model_setup"], "default_connection": "default"} - }, - } - ) - - async def test_dup1_init(self): - with self.assertRaisesRegex( - ConfigurationError, 'backward relation "events" duplicates in model Tournament' - ): - await 
Tortoise.init( - { - "connections": { - "default": { - "engine": "tortoise.backends.sqlite", - "credentials": {"file_path": ":memory:"}, - } - }, - "apps": { - "models": { - "models": ["tests.model_setup.models_dup1"], - "default_connection": "default", - } - }, - } - ) - - async def test_dup2_init(self): - with self.assertRaisesRegex( - ConfigurationError, 'backward relation "events" duplicates in model Team' - ): - await Tortoise.init( - { - "connections": { - "default": { - "engine": "tortoise.backends.sqlite", - "credentials": {"file_path": ":memory:"}, - } - }, - "apps": { - "models": { - "models": ["tests.model_setup.models_dup2"], - "default_connection": "default", - } - }, - } - ) - - async def test_dup3_init(self): - with self.assertRaisesRegex( - ConfigurationError, 'backward relation "event" duplicates in model Tournament' - ): - await Tortoise.init( - { - "connections": { - "default": { - "engine": "tortoise.backends.sqlite", - "credentials": {"file_path": ":memory:"}, - } - }, - "apps": { - "models": { - "models": ["tests.model_setup.models_dup3"], - "default_connection": "default", - } - }, - } - ) - - async def test_generated_nonint(self): - with self.assertRaisesRegex( - ConfigurationError, "Field 'val' \\(CharField\\) can't be DB-generated" - ): - await Tortoise.init( - { - "connections": { - "default": { - "engine": "tortoise.backends.sqlite", - "credentials": {"file_path": ":memory:"}, - } - }, - "apps": { - "models": { - "models": ["tests.model_setup.model_generated_nonint"], - "default_connection": "default", - } - }, - } - ) - - async def test_multiple_pk(self): - with self.assertRaisesRegex( - ConfigurationError, - "Can't create model Tournament with two primary keys, only single primary key is supported", - ): - await Tortoise.init( - { - "connections": { - "default": { - "engine": "tortoise.backends.sqlite", - "credentials": {"file_path": ":memory:"}, - } - }, - "apps": { - "models": { - "models": ["tests.model_setup.model_multiple_pk"], - 
"default_connection": "default", - } - }, - } - ) - - async def test_nonpk_id(self): - with self.assertRaisesRegex( - ConfigurationError, - "Can't create model Tournament without explicit primary key if" - " field 'id' already present", - ): - await Tortoise.init( - { - "connections": { - "default": { - "engine": "tortoise.backends.sqlite", - "credentials": {"file_path": ":memory:"}, - } - }, - "apps": { - "models": { - "models": ["tests.model_setup.model_nonpk_id"], - "default_connection": "default", - } - }, - } - ) - - async def test_unknown_connection(self): - with self.assertRaisesRegex( - ConfigurationError, - "Unable to get db settings for alias 'fioop'. Please " - "check if the config dict contains this alias and try again", - ): - await Tortoise.init( - { - "connections": { - "default": { - "engine": "tortoise.backends.sqlite", - "credentials": {"file_path": ":memory:"}, - } - }, - "apps": { - "models": {"models": ["tests.testmodels"], "default_connection": "fioop"} - }, - } - ) - async def test_init_connections_false(self): - config = { + +@pytest.mark.asyncio +async def test_dup2_init(): + await _reset_tortoise() + with pytest.raises( + ConfigurationError, match='backward relation "events" duplicates in model Team' + ): + await Tortoise.init( + { + "connections": { + "default": { + "engine": "tortoise.backends.sqlite", + "credentials": {"file_path": ":memory:"}, + } + }, + "apps": { + "models": { + "models": ["tests.model_setup.models_dup2"], + "default_connection": "default", + } + }, + } + ) + + +@pytest.mark.asyncio +async def test_dup3_init(): + await _reset_tortoise() + with pytest.raises( + ConfigurationError, match='backward relation "event" duplicates in model Tournament' + ): + await Tortoise.init( + { + "connections": { + "default": { + "engine": "tortoise.backends.sqlite", + "credentials": {"file_path": ":memory:"}, + } + }, + "apps": { + "models": { + "models": ["tests.model_setup.models_dup3"], + "default_connection": "default", + } + }, + } 
+ ) + + +@pytest.mark.asyncio +async def test_generated_nonint(): + await _reset_tortoise() + with pytest.raises( + ConfigurationError, match="Field 'val' \\(CharField\\) can't be DB-generated" + ): + await Tortoise.init( + { + "connections": { + "default": { + "engine": "tortoise.backends.sqlite", + "credentials": {"file_path": ":memory:"}, + } + }, + "apps": { + "models": { + "models": ["tests.model_setup.model_generated_nonint"], + "default_connection": "default", + } + }, + } + ) + + +@pytest.mark.asyncio +async def test_multiple_pk(): + await _reset_tortoise() + with pytest.raises( + ConfigurationError, + match="Can't create model Tournament with two primary keys, only single primary key is supported", + ): + await Tortoise.init( + { + "connections": { + "default": { + "engine": "tortoise.backends.sqlite", + "credentials": {"file_path": ":memory:"}, + } + }, + "apps": { + "models": { + "models": ["tests.model_setup.model_multiple_pk"], + "default_connection": "default", + } + }, + } + ) + + +@pytest.mark.asyncio +async def test_nonpk_id(): + await _reset_tortoise() + with pytest.raises( + ConfigurationError, + match="Can't create model Tournament without explicit primary key if" + " field 'id' already present", + ): + await Tortoise.init( + { + "connections": { + "default": { + "engine": "tortoise.backends.sqlite", + "credentials": {"file_path": ":memory:"}, + } + }, + "apps": { + "models": { + "models": ["tests.model_setup.model_nonpk_id"], + "default_connection": "default", + } + }, + } + ) + + +@pytest.mark.asyncio +async def test_unknown_connection(): + await _reset_tortoise() + with pytest.raises( + ConfigurationError, + match='App "models" refers to unknown connection "fioop"', + ): + await Tortoise.init( + { + "connections": { + "default": { + "engine": "tortoise.backends.sqlite", + "credentials": {"file_path": ":memory:"}, + } + }, + "apps": {"models": {"models": ["tests.testmodels"], "default_connection": "fioop"}}, + } + ) + + +@pytest.mark.asyncio 
+async def test_init_connections_false(): + await _reset_tortoise() + config = { + "connections": { + "default": { + "engine": "tortoise.backends.sqlite", + "credentials": {"file_path": ":memory:"}, + } + }, + "apps": {"models": {"models": ["tests.testmodels"], "default_connection": "default"}}, + } + with ( + patch("tortoise.connections._init") as mocked_init, + patch("tortoise.connections.get") as mocked_get, + ): + await Tortoise.init(config=config, init_connections=False) + mocked_init.assert_not_called() + mocked_get.assert_not_called() + assert "models" in Tortoise.apps + assert connections.db_config == config["connections"] + + +@pytest.mark.asyncio +async def test_init_connections_false_with_create_db(): + await _reset_tortoise() + config = { + "connections": { + "default": { + "engine": "tortoise.backends.sqlite", + "credentials": {"file_path": ":memory:"}, + } + }, + "apps": {"models": {"models": ["tests.testmodels"], "default_connection": "default"}}, + } + with pytest.raises( + ConfigurationError, match="init_connections=False cannot be used with _create_db=True" + ): + await Tortoise.init(config=config, _create_db=True, init_connections=False) + + +@pytest.mark.asyncio +async def test_url_without_modules(): + await _reset_tortoise() + with pytest.raises( + ConfigurationError, + match="Must provide either 'config', 'config_file', or both 'db_url' and 'modules'", + ): + await Tortoise.init(db_url=f"sqlite://{':memory:'}") + + +@pytest.mark.asyncio +async def test_default_connection_init(): + await _reset_tortoise() + await Tortoise.init( + { "connections": { "default": { "engine": "tortoise.backends.sqlite", "credentials": {"file_path": ":memory:"}, } }, - "apps": {"models": {"models": ["tests.testmodels"], "default_connection": "default"}}, + "apps": {"models": {"models": ["tests.testmodels"]}}, } - with ( - patch("tortoise.connections._init") as mocked_init, - patch("tortoise.connections.get") as mocked_get, - ): - await Tortoise.init(config=config, 
init_connections=False) - mocked_init.assert_not_called() - mocked_get.assert_not_called() - self.assertIn("models", Tortoise.apps) - self.assertEqual(connections.db_config, config["connections"]) - - async def test_init_connections_false_with_create_db(self): - config = { - "connections": { - "default": { - "engine": "tortoise.backends.sqlite", - "credentials": {"file_path": ":memory:"}, - } - }, + ) + assert "models" in Tortoise.apps + assert connections.get("default") is not None + + +@pytest.mark.asyncio +async def test_db_url_init(): + await _reset_tortoise() + await Tortoise.init( + { + "connections": {"default": f"sqlite://{':memory:'}"}, "apps": {"models": {"models": ["tests.testmodels"], "default_connection": "default"}}, } - with self.assertRaisesRegex( - ConfigurationError, "init_connections=False cannot be used with _create_db=True" - ): - await Tortoise.init(config=config, _create_db=True, init_connections=False) - - async def test_url_without_modules(self): - with self.assertRaisesRegex( - ConfigurationError, 'You must specify "db_url" and "modules" together' - ): - await Tortoise.init(db_url=f"sqlite://{':memory:'}") - - async def test_default_connection_init(self): + ) + assert "models" in Tortoise.apps + assert connections.get("default") is not None + + +@pytest.mark.asyncio +async def test_shorthand_init(): + await _reset_tortoise() + await Tortoise.init(db_url=f"sqlite://{':memory:'}", modules={"models": ["tests.testmodels"]}) + assert "models" in Tortoise.apps + assert connections.get("default") is not None + + +@pytest.mark.asyncio +async def test_init_wrong_connection_engine(): + await _reset_tortoise() + with pytest.raises(ImportError, match="tortoise.backends.test"): await Tortoise.init( { "connections": { "default": { - "engine": "tortoise.backends.sqlite", + "engine": "tortoise.backends.test", "credentials": {"file_path": ":memory:"}, } }, - "apps": {"models": {"models": ["tests.testmodels"]}}, + "apps": { + "models": {"models": 
["tests.testmodels"], "default_connection": "default"} + }, } ) - self.assertIn("models", Tortoise.apps) - self.assertIsNotNone(connections.get("default")) - async def test_db_url_init(self): + +@pytest.mark.asyncio +async def test_init_wrong_connection_engine_2(): + await _reset_tortoise() + with pytest.raises( + ConfigurationError, + match='Backend for engine "tortoise.backends" does not implement db client', + ): await Tortoise.init( { - "connections": {"default": f"sqlite://{':memory:'}"}, + "connections": { + "default": { + "engine": "tortoise.backends", + "credentials": {"file_path": ":memory:"}, + } + }, "apps": { "models": {"models": ["tests.testmodels"], "default_connection": "default"} }, } ) - self.assertIn("models", Tortoise.apps) - self.assertIsNotNone(connections.get("default")) - async def test_shorthand_init(self): + +@pytest.mark.asyncio +async def test_init_no_connections(): + await _reset_tortoise() + with pytest.raises(ConfigurationError, match='Config must define "connections" section'): await Tortoise.init( - db_url=f"sqlite://{':memory:'}", modules={"models": ["tests.testmodels"]} + {"apps": {"models": {"models": ["tests.testmodels"], "default_connection": "default"}}} ) - self.assertIn("models", Tortoise.apps) - self.assertIsNotNone(connections.get("default")) - - async def test_init_wrong_connection_engine(self): - with self.assertRaisesRegex(ImportError, "tortoise.backends.test"): - await Tortoise.init( - { - "connections": { - "default": { - "engine": "tortoise.backends.test", - "credentials": {"file_path": ":memory:"}, - } - }, - "apps": { - "models": {"models": ["tests.testmodels"], "default_connection": "default"} - }, - } - ) - - async def test_init_wrong_connection_engine_2(self): - with self.assertRaisesRegex( - ConfigurationError, - 'Backend for engine "tortoise.backends" does not implement db client', - ): - await Tortoise.init( - { - "connections": { - "default": { - "engine": "tortoise.backends", - "credentials": {"file_path": 
":memory:"}, - } - }, - "apps": { - "models": {"models": ["tests.testmodels"], "default_connection": "default"} - }, - } - ) - - async def test_init_no_connections(self): - with self.assertRaisesRegex(ConfigurationError, 'Config must define "connections" section'): - await Tortoise.init( - { - "apps": { - "models": {"models": ["tests.testmodels"], "default_connection": "default"} + + +@pytest.mark.asyncio +async def test_init_no_apps(): + await _reset_tortoise() + with pytest.raises(ConfigurationError, match='Config must define "apps" section'): + await Tortoise.init( + { + "connections": { + "default": { + "engine": "tortoise.backends.sqlite", + "credentials": {"file_path": ":memory:"}, } } - ) - - async def test_init_no_apps(self): - with self.assertRaisesRegex(ConfigurationError, 'Config must define "apps" section'): - await Tortoise.init( - { - "connections": { - "default": { - "engine": "tortoise.backends.sqlite", - "credentials": {"file_path": ":memory:"}, - } + } + ) + + +@pytest.mark.asyncio +async def test_init_config_and_config_file(): + await _reset_tortoise() + with pytest.raises( + ConfigurationError, match='You should init either from "config", "config_file" or "db_url"' + ): + await Tortoise.init( + { + "connections": { + "default": { + "engine": "tortoise.backends.sqlite", + "credentials": {"file_path": ":memory:"}, } - } - ) - - async def test_init_config_and_config_file(self): - with self.assertRaisesRegex( - ConfigurationError, 'You should init either from "config", "config_file" or "db_url"' - ): - await Tortoise.init( - { - "connections": { - "default": { - "engine": "tortoise.backends.sqlite", - "credentials": {"file_path": ":memory:"}, - } - }, - "apps": { - "models": {"models": ["tests.testmodels"], "default_connection": "default"} - }, }, - config_file="file.json", - ) - - async def test_init_config_file_wrong_extension(self): - with self.assertRaisesRegex( - ConfigurationError, "Unknown config extension .ini, only .yml and .json are 
supported" - ): - await Tortoise.init(config_file="config.ini") - - @test.skipIf(os.name == "nt", "path issue on Windows") - async def test_init_json_file(self): - await Tortoise.init(config_file=os.path.dirname(__file__) + "/init.json") - self.assertIn("models", Tortoise.apps) - self.assertIsNotNone(connections.get("default")) - - @test.skipIf(os.name == "nt", "path issue on Windows") - async def test_init_yaml_file(self): - await Tortoise.init(config_file=os.path.dirname(__file__) + "/init.yaml") - self.assertIn("models", Tortoise.apps) - self.assertIsNotNone(connections.get("default")) - - async def test_generate_schema_without_init(self): - with self.assertRaisesRegex( - ConfigurationError, r"You have to call \.init\(\) first before generating schemas" - ): - await Tortoise.generate_schemas() - - async def test_drop_databases_without_init(self): - with self.assertRaisesRegex( - ConfigurationError, r"You have to call \.init\(\) first before deleting schemas" - ): - await Tortoise._drop_databases() - - async def test_bad_models(self): - with self.assertRaisesRegex(ConfigurationError, 'Module "tests.testmodels2" not found'): - await Tortoise.init( - { - "connections": { - "default": { - "engine": "tortoise.backends.sqlite", - "credentials": {"file_path": ":memory:"}, - } - }, - "apps": { - "models": {"models": ["tests.testmodels2"], "default_connection": "default"} - }, - } - ) + "apps": { + "models": {"models": ["tests.testmodels"], "default_connection": "default"} + }, + }, + config_file="file.json", + ) + + +@pytest.mark.asyncio +async def test_init_config_file_wrong_extension(): + await _reset_tortoise() + with pytest.raises( + ConfigurationError, match="Unknown config extension .ini, only .yml and .json are supported" + ): + await Tortoise.init(config_file="config.ini") + + +@pytest.mark.skipif(os.name == "nt", reason="path issue on Windows") +@pytest.mark.asyncio +async def test_init_json_file(): + await _reset_tortoise() + await 
Tortoise.init(config_file=os.path.dirname(__file__) + "/init.json") + assert "models" in Tortoise.apps + assert connections.get("default") is not None + + +@pytest.mark.skipif(os.name == "nt", reason="path issue on Windows") +@pytest.mark.asyncio +async def test_init_yaml_file(): + await _reset_tortoise() + await Tortoise.init(config_file=os.path.dirname(__file__) + "/init.yaml") + assert "models" in Tortoise.apps + assert connections.get("default") is not None + + +@pytest.mark.asyncio +async def test_generate_schema_without_init(): + await _reset_tortoise() + with pytest.raises( + ConfigurationError, match=r"You have to call \.init\(\) first before generating schemas" + ): + await Tortoise.generate_schemas() + + +@pytest.mark.asyncio +async def test_drop_databases_without_init(): + await _reset_tortoise() + with pytest.raises( + ConfigurationError, match=r"You have to call \.init\(\) first before deleting schemas" + ): + await Tortoise._drop_databases() + + +@pytest.mark.asyncio +async def test_bad_models(): + await _reset_tortoise() + with pytest.raises(ConfigurationError, match='Module "tests.testmodels2" not found'): + await Tortoise.init( + { + "connections": { + "default": { + "engine": "tortoise.backends.sqlite", + "credentials": {"file_path": ":memory:"}, + } + }, + "apps": { + "models": {"models": ["tests.testmodels2"], "default_connection": "default"} + }, + } + ) diff --git a/tests/schema/test_generate_schema.py b/tests/schema/test_generate_schema.py index c62cb3735..fb46093b9 100644 --- a/tests/schema/test_generate_schema.py +++ b/tests/schema/test_generate_schema.py @@ -1,15 +1,23 @@ # pylint: disable=C0301 +import os import re from unittest.mock import MagicMock, patch +import pytest + from tortoise import Tortoise, connections -from tortoise.contrib import test +from tortoise.backends.base.config_generator import generate_config +from tortoise.context import TortoiseContext, get_current_context from tortoise.exceptions import ConfigurationError from 
tortoise.utils import get_schema_sql +# Save original classproperties before any test can shadow them +_original_apps_prop = Tortoise.__dict__["apps"] +_original_inited_prop = Tortoise.__dict__["_inited"] + -class TestGenerateSchema(test.SimpleTestCase): - safe_schema_sql = """ +# Safe schema SQL expected for SQLite +SAFE_SCHEMA_SQL = """ CREATE TABLE IF NOT EXISTS "company" ( "id" INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, "name" TEXT NOT NULL, @@ -92,152 +100,277 @@ class TestGenerateSchema(test.SimpleTestCase): CREATE UNIQUE INDEX IF NOT EXISTS "uidx_teamevents_event_i_664dbc" ON "teamevents" ("event_id", "team_id"); """.strip() - async def asyncSetUp(self): - await super().asyncSetUp() - try: - Tortoise.apps = None - Tortoise._inited = False - except ConfigurationError: - pass - Tortoise._inited = False - self.sqls = [] - self.post_sqls = [] - self.engine = test.getDBConfig(app_label="models", modules=[])["connections"]["models"][ - "engine" - ] - - async def asyncTearDown(self) -> None: - await Tortoise._reset_apps() - await super().asyncTearDown() - - async def init_for(self, module: str, safe=False) -> None: - with patch( - "tortoise.backends.sqlite.client.SqliteClient.create_connection", new=MagicMock() - ): - await Tortoise.init( - { - "connections": { - "default": { - "engine": "tortoise.backends.sqlite", - "credentials": {"file_path": ":memory:"}, - } - }, - "apps": {"models": {"models": [module], "default_connection": "default"}}, - } - ) - self.sqls = get_schema_sql(connections.get("default"), safe).split(";\n") - - def get_sql(self, text: str) -> str: - return re.sub(r"[ \t\n\r]+", " ", " ".join([sql for sql in self.sqls if text in sql])) - - async def test_noid(self): - await self.init_for("tests.testmodels") - sql = self.get_sql('"noid"') - self.assertIn('"name" VARCHAR(255)', sql) - self.assertIn('"id" INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL', sql) - - async def test_minrelation(self): - await self.init_for("tests.testmodels") - sql = 
self.get_sql('"minrelation"') - self.assertIn( - '"tournament_id" SMALLINT NOT NULL REFERENCES "tournament" ("id") ON DELETE CASCADE', - sql, + +async def _reset_tortoise(): + """Helper to reset Tortoise state before each test. + + Note: We MUST NOT set Tortoise.apps = None or Tortoise._inited = False + because these are classproperties and setting them shadows the property + with a class attribute, breaking future access. + """ + # Restore original classproperties if they were shadowed + if not isinstance(Tortoise.__dict__.get("apps"), type(_original_apps_prop)): + type.__setattr__(Tortoise, "apps", _original_apps_prop) + if not isinstance(Tortoise.__dict__.get("_inited"), type(_original_inited_prop)): + type.__setattr__(Tortoise, "_inited", _original_inited_prop) + + # Get the current context and properly reset it + ctx = get_current_context() + if ctx is not None: + # Clear db_config first to prevent close_all from trying to import bad backends + if ctx._connections is not None: + # Clear storage without closing (to avoid importing bad backends) + ctx._connections._storage.clear() + ctx._connections._db_config = None + ctx._connections = None + ctx._apps = None + ctx._inited = False + ctx._default_connection = None + else: + # No context exists - create one for the test + ctx = TortoiseContext() + ctx.__enter__() + + +async def _teardown_tortoise(): + """Helper to teardown Tortoise state after each test.""" + await Tortoise._reset_apps() + + +def _get_engine(): + """Get the current test engine.""" + db_url = os.getenv("TORTOISE_TEST_DB", "sqlite://:memory:") + config = generate_config(db_url, app_modules={"models": []}, connection_label="models") + return config["connections"]["models"]["engine"] + + +def _get_sql(sqls: list[str], text: str) -> str: + """Get SQL statement containing the given text.""" + return re.sub(r"[ \t\n\r]+", " ", " ".join([sql for sql in sqls if text in sql])) + + +# 
============================================================================ +# SQLite Tests +# ============================================================================ + + +async def _init_for_sqlite(module: str, safe: bool = False) -> list[str]: + """Initialize Tortoise for SQLite and return SQL statements.""" + with patch("tortoise.backends.sqlite.client.SqliteClient.create_connection", new=MagicMock()): + await Tortoise.init( + { + "connections": { + "default": { + "engine": "tortoise.backends.sqlite", + "credentials": {"file_path": ":memory:"}, + } + }, + "apps": {"models": {"models": [module], "default_connection": "default"}}, + } + ) + return get_schema_sql(connections.get("default"), safe).split(";\n") + + +@pytest.mark.asyncio +async def test_noid(): + await _reset_tortoise() + try: + sqls = await _init_for_sqlite("tests.testmodels") + sql = _get_sql(sqls, '"noid"') + assert '"name" VARCHAR(255)' in sql + assert '"id" INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL' in sql + finally: + await _teardown_tortoise() + + +@pytest.mark.asyncio +async def test_minrelation(): + await _reset_tortoise() + try: + sqls = await _init_for_sqlite("tests.testmodels") + sql = _get_sql(sqls, '"minrelation"') + assert ( + '"tournament_id" SMALLINT NOT NULL REFERENCES "tournament" ("id") ON DELETE CASCADE' + in sql ) - self.assertNotIn("participants", sql) + assert "participants" not in sql - sql = self.get_sql('"minrelation_team"') - self.assertIn( - '"minrelation_id" INT NOT NULL REFERENCES "minrelation" ("id") ON DELETE CASCADE', sql + sql = _get_sql(sqls, '"minrelation_team"') + assert ( + '"minrelation_id" INT NOT NULL REFERENCES "minrelation" ("id") ON DELETE CASCADE' in sql ) - self.assertIn('"team_id" INT NOT NULL REFERENCES "team" ("id") ON DELETE CASCADE', sql) - - async def test_safe_generation(self): - """Assert that the IF NOT EXISTS clause is included when safely generating schema.""" - await self.init_for("tests.testmodels", True) - sql = self.get_sql("") - 
self.assertIn("IF NOT EXISTS", sql) - - async def test_unsafe_generation(self): - """Assert that the IF NOT EXISTS clause is not included when generating schema.""" - await self.init_for("tests.testmodels", False) - sql = self.get_sql("") - self.assertNotIn("IF NOT EXISTS", sql) - - async def test_cyclic(self): - with self.assertRaisesRegex( - ConfigurationError, "Can't create schema due to cyclic fk references" + assert '"team_id" INT NOT NULL REFERENCES "team" ("id") ON DELETE CASCADE' in sql + finally: + await _teardown_tortoise() + + +@pytest.mark.asyncio +async def test_safe_generation(): + """Assert that the IF NOT EXISTS clause is included when safely generating schema.""" + await _reset_tortoise() + try: + sqls = await _init_for_sqlite("tests.testmodels", True) + sql = _get_sql(sqls, "") + assert "IF NOT EXISTS" in sql + finally: + await _teardown_tortoise() + + +@pytest.mark.asyncio +async def test_unsafe_generation(): + """Assert that the IF NOT EXISTS clause is not included when generating schema.""" + await _reset_tortoise() + try: + sqls = await _init_for_sqlite("tests.testmodels", False) + sql = _get_sql(sqls, "") + assert "IF NOT EXISTS" not in sql + finally: + await _teardown_tortoise() + + +@pytest.mark.asyncio +async def test_cyclic(): + await _reset_tortoise() + try: + with pytest.raises( + ConfigurationError, match="Can't create schema due to cyclic fk references" ): - await self.init_for("tests.schema.models_cyclic") + await _init_for_sqlite("tests.schema.models_cyclic") + finally: + await _teardown_tortoise() + + +@pytest.mark.asyncio +async def test_create_index(): + await _reset_tortoise() + try: + sqls = await _init_for_sqlite("tests.testmodels") + sql = _get_sql(sqls, "CREATE INDEX") + assert re.search(r"idx_tournament_created_\w+", sql) is not None + finally: + await _teardown_tortoise() - async def test_create_index(self): - await self.init_for("tests.testmodels") - sql = self.get_sql("CREATE INDEX") - 
self.assertIsNotNone(re.search(r"idx_tournament_created_\w+", sql)) - async def test_create_index_with_custom_name(self): - await self.init_for("tests.testmodels") - sql = self.get_sql("f3") - self.assertIn("model_with_indexes__f3", sql) +@pytest.mark.asyncio +async def test_create_index_with_custom_name(): + await _reset_tortoise() + try: + sqls = await _init_for_sqlite("tests.testmodels") + sql = _get_sql(sqls, "f3") + assert "model_with_indexes__f3" in sql + finally: + await _teardown_tortoise() - async def test_fk_bad_model_name(self): - with self.assertRaisesRegex( - ConfigurationError, 'ForeignKeyField accepts model name in format "app.Model"' + +@pytest.mark.asyncio +async def test_fk_bad_model_name(): + await _reset_tortoise() + try: + with pytest.raises( + ConfigurationError, match='ForeignKeyField accepts model name in format "app.Model"' ): - await self.init_for("tests.schema.models_fk_1") + await _init_for_sqlite("tests.schema.models_fk_1") + finally: + await _teardown_tortoise() + - async def test_fk_bad_on_delete(self): - with self.assertRaisesRegex( +@pytest.mark.asyncio +async def test_fk_bad_on_delete(): + await _reset_tortoise() + try: + with pytest.raises( ConfigurationError, - "on_delete can only be CASCADE, RESTRICT, SET_NULL, SET_DEFAULT or NO_ACTION", + match="on_delete can only be CASCADE, RESTRICT, SET_NULL, SET_DEFAULT or NO_ACTION", ): - await self.init_for("tests.schema.models_fk_2") + await _init_for_sqlite("tests.schema.models_fk_2") + finally: + await _teardown_tortoise() - async def test_fk_bad_null(self): - with self.assertRaisesRegex( - ConfigurationError, "If on_delete is SET_NULL, then field must have null=True set" + +@pytest.mark.asyncio +async def test_fk_bad_null(): + await _reset_tortoise() + try: + with pytest.raises( + ConfigurationError, match="If on_delete is SET_NULL, then field must have null=True set" ): - await self.init_for("tests.schema.models_fk_3") + await _init_for_sqlite("tests.schema.models_fk_3") + finally: + 
await _teardown_tortoise() + - async def test_o2o_bad_on_delete(self): - with self.assertRaisesRegex( +@pytest.mark.asyncio +async def test_o2o_bad_on_delete(): + await _reset_tortoise() + try: + with pytest.raises( ConfigurationError, - "on_delete can only be CASCADE, RESTRICT, SET_NULL, SET_DEFAULT or NO_ACTION", + match="on_delete can only be CASCADE, RESTRICT, SET_NULL, SET_DEFAULT or NO_ACTION", ): - await self.init_for("tests.schema.models_o2o_2") + await _init_for_sqlite("tests.schema.models_o2o_2") + finally: + await _teardown_tortoise() - async def test_o2o_bad_null(self): - with self.assertRaisesRegex( - ConfigurationError, "If on_delete is SET_NULL, then field must have null=True set" + +@pytest.mark.asyncio +async def test_o2o_bad_null(): + await _reset_tortoise() + try: + with pytest.raises( + ConfigurationError, match="If on_delete is SET_NULL, then field must have null=True set" ): - await self.init_for("tests.schema.models_o2o_3") + await _init_for_sqlite("tests.schema.models_o2o_3") + finally: + await _teardown_tortoise() + - async def test_m2m_bad_model_name(self): - with self.assertRaisesRegex( - ConfigurationError, 'ManyToManyField accepts model name in format "app.Model"' +@pytest.mark.asyncio +async def test_m2m_bad_model_name(): + await _reset_tortoise() + try: + with pytest.raises( + ConfigurationError, match='ManyToManyField accepts model name in format "app.Model"' ): - await self.init_for("tests.schema.models_m2m_1") - - async def test_multi_m2m_fields_in_a_model(self): - await self.init_for("tests.schema.models_m2m_2") - sql = self.get_sql("CASCADE") - self.assertNotRegex(sql, r'REFERENCES [`"]three_one[`"]') - self.assertNotRegex(sql, r'REFERENCES [`"]three_two[`"]') - self.assertRegex(sql, r'REFERENCES [`"](one|two|three)[`"]') - - async def test_table_and_row_comment_generation(self): - await self.init_for("tests.testmodels") - sql = self.get_sql("comments") - self.assertRegex(sql, r".*\/\* Upvotes done on the comment.*\*\/") - 
self.assertRegex(sql, r".*\\n.*") - self.assertIn("\\/", sql) - - async def test_schema_no_db_constraint(self): - self.maxDiff = None - await self.init_for("tests.schema.models_no_db_constraint") + await _init_for_sqlite("tests.schema.models_m2m_1") + finally: + await _teardown_tortoise() + + +@pytest.mark.asyncio +async def test_multi_m2m_fields_in_a_model(): + await _reset_tortoise() + try: + sqls = await _init_for_sqlite("tests.schema.models_m2m_2") + sql = _get_sql(sqls, "CASCADE") + assert not re.search(r'REFERENCES [`"]three_one[`"]', sql) + assert not re.search(r'REFERENCES [`"]three_two[`"]', sql) + assert re.search(r'REFERENCES [`"](one|two|three)[`"]', sql) + finally: + await _teardown_tortoise() + + +@pytest.mark.asyncio +async def test_table_and_row_comment_generation(): + await _reset_tortoise() + try: + sqls = await _init_for_sqlite("tests.testmodels") + sql = _get_sql(sqls, "comments") + assert re.search(r".*\/\* Upvotes done on the comment.*\*\/", sql) + assert re.search(r".*\\n.*", sql) + assert "\\/" in sql + finally: + await _teardown_tortoise() + + +@pytest.mark.asyncio +async def test_schema_no_db_constraint(): + await _reset_tortoise() + try: + await _init_for_sqlite("tests.schema.models_no_db_constraint") sql = get_schema_sql(connections.get("default"), safe=False) - self.assertEqual( - sql.strip(), - r"""CREATE TABLE "team" ( + assert ( + sql.strip() + == r"""CREATE TABLE "team" ( "name" VARCHAR(50) NOT NULL PRIMARY KEY /* The TEAM name (and PK) */, "key" INT NOT NULL, "manager_id" VARCHAR(50) @@ -270,16 +403,21 @@ async def test_schema_no_db_constraint(self): "event_id" BIGINT NOT NULL, "team_id" VARCHAR(50) NOT NULL ) /* How participants relate */; -CREATE UNIQUE INDEX "uidx_teamevents_event_i_664dbc" ON "teamevents" ("event_id", "team_id");""", +CREATE UNIQUE INDEX "uidx_teamevents_event_i_664dbc" ON "teamevents" ("event_id", "team_id");""" ) + finally: + await _teardown_tortoise() - async def test_schema(self): - self.maxDiff = None - 
await self.init_for("tests.schema.models_schema_create") + +@pytest.mark.asyncio +async def test_schema(): + await _reset_tortoise() + try: + await _init_for_sqlite("tests.schema.models_schema_create") sql = get_schema_sql(connections.get("default"), safe=False) - self.assertEqual( - sql.strip(), - """ + assert ( + sql.strip() + == """ CREATE TABLE "company" ( "id" INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, "name" TEXT NOT NULL, @@ -360,22 +498,32 @@ async def test_schema(self): "team_id" VARCHAR(50) NOT NULL REFERENCES "team" ("name") ON DELETE SET NULL ) /* How participants relate */; CREATE UNIQUE INDEX "uidx_teamevents_event_i_664dbc" ON "teamevents" ("event_id", "team_id"); -""".strip(), +""".strip() ) + finally: + await _teardown_tortoise() - async def test_schema_safe(self): - self.maxDiff = None - await self.init_for("tests.schema.models_schema_create") + +@pytest.mark.asyncio +async def test_schema_safe(): + await _reset_tortoise() + try: + await _init_for_sqlite("tests.schema.models_schema_create") sql = get_schema_sql(connections.get("default"), safe=True) - self.assertEqual(sql.strip(), self.safe_schema_sql) + assert sql.strip() == SAFE_SCHEMA_SQL + finally: + await _teardown_tortoise() + - async def test_m2m_no_auto_create(self): - self.maxDiff = None - await self.init_for("tests.schema.models_no_auto_create_m2m") +@pytest.mark.asyncio +async def test_m2m_no_auto_create(): + await _reset_tortoise() + try: + await _init_for_sqlite("tests.schema.models_no_auto_create_m2m") sql = get_schema_sql(connections.get("default"), safe=False) - self.assertEqual( - sql.strip(), - r"""CREATE TABLE "team" ( + assert ( + sql.strip() + == r"""CREATE TABLE "team" ( "name" VARCHAR(50) NOT NULL PRIMARY KEY /* The TEAM name (and PK) */, "key" INT NOT NULL, "manager_id" VARCHAR(50) REFERENCES "team" ("name") ON DELETE CASCADE @@ -411,80 +559,115 @@ async def test_m2m_no_auto_create(self): "team_id" VARCHAR(50) NOT NULL REFERENCES "team" ("name") ON DELETE CASCADE ); CREATE 
UNIQUE INDEX "uidx_team_team_team_re_d994df" ON "team_team" ("team_rel_id", "team_id"); -""".strip(), +""".strip() ) + finally: + await _teardown_tortoise() -class TestGenerateSchemaMySQL(TestGenerateSchema): - async def init_for(self, module: str, safe=False) -> None: - try: - with patch("aiomysql.create_pool", new=MagicMock()): - await Tortoise.init( - { - "connections": { - "default": { - "engine": "tortoise.backends.mysql", - "credentials": { - "database": "test", - "host": "127.0.0.1", - "password": "foomip", - "port": 3306, - "user": "root", - "connect_timeout": 1.5, - "charset": "utf8mb4", - }, - } - }, - "apps": {"models": {"models": [module], "default_connection": "default"}}, - } - ) - self.sqls = get_schema_sql(connections.get("default"), safe).split("; ") - except ImportError: - raise test.SkipTest("aiomysql not installed") - - async def test_noid(self): - await self.init_for("tests.testmodels") - sql = self.get_sql("`noid`") - self.assertIn("`name` VARCHAR(255)", sql) - self.assertIn("`id` INT NOT NULL PRIMARY KEY AUTO_INCREMENT", sql) - - async def test_create_index(self): - await self.init_for("tests.testmodels") - sql = self.get_sql("KEY") - self.assertIsNotNone(re.search(r"idx_tournament_created_\w+", sql)) - - async def test_minrelation(self): - await self.init_for("tests.testmodels") - sql = self.get_sql("`minrelation`") - self.assertIn("`tournament_id` SMALLINT NOT NULL,", sql) - self.assertIn( - "FOREIGN KEY (`tournament_id`) REFERENCES `tournament` (`id`) ON DELETE CASCADE", sql +# ============================================================================ +# MySQL Tests +# ============================================================================ + + +async def _init_for_mysql(module: str, safe: bool = False) -> list[str]: + """Initialize Tortoise for MySQL and return SQL statements.""" + try: + with patch("aiomysql.create_pool", new=MagicMock()): + await Tortoise.init( + { + "connections": { + "default": { + "engine": 
"tortoise.backends.mysql", + "credentials": { + "database": "test", + "host": "127.0.0.1", + "password": "foomip", + "port": 3306, + "user": "root", + "connect_timeout": 1.5, + "charset": "utf8mb4", + }, + } + }, + "apps": {"models": {"models": [module], "default_connection": "default"}}, + } + ) + return get_schema_sql(connections.get("default"), safe).split("; ") + except ImportError: + pytest.skip("aiomysql not installed") + + +@pytest.mark.asyncio +async def test_mysql_noid(): + await _reset_tortoise() + try: + sqls = await _init_for_mysql("tests.testmodels") + sql = _get_sql(sqls, "`noid`") + assert "`name` VARCHAR(255)" in sql + assert "`id` INT NOT NULL PRIMARY KEY AUTO_INCREMENT" in sql + finally: + await _teardown_tortoise() + + +@pytest.mark.asyncio +async def test_mysql_create_index(): + await _reset_tortoise() + try: + sqls = await _init_for_mysql("tests.testmodels") + sql = _get_sql(sqls, "KEY") + assert re.search(r"idx_tournament_created_\w+", sql) is not None + finally: + await _teardown_tortoise() + + +@pytest.mark.asyncio +async def test_mysql_minrelation(): + await _reset_tortoise() + try: + sqls = await _init_for_mysql("tests.testmodels") + sql = _get_sql(sqls, "`minrelation`") + assert "`tournament_id` SMALLINT NOT NULL," in sql + assert ( + "FOREIGN KEY (`tournament_id`) REFERENCES `tournament` (`id`) ON DELETE CASCADE" in sql ) - self.assertNotIn("participants", sql) + assert "participants" not in sql - sql = self.get_sql("`minrelation_team`") - self.assertIn("`minrelation_id` INT NOT NULL", sql) - self.assertIn( - "FOREIGN KEY (`minrelation_id`) REFERENCES `minrelation` (`id`) ON DELETE CASCADE", sql + sql = _get_sql(sqls, "`minrelation_team`") + assert "`minrelation_id` INT NOT NULL" in sql + assert ( + "FOREIGN KEY (`minrelation_id`) REFERENCES `minrelation` (`id`) ON DELETE CASCADE" + in sql ) - self.assertIn("`team_id` INT NOT NULL", sql) - self.assertIn("FOREIGN KEY (`team_id`) REFERENCES `team` (`id`) ON DELETE CASCADE", sql) - - async 
def test_table_and_row_comment_generation(self): - await self.init_for("tests.testmodels") - sql = self.get_sql("comments") - self.assertIn("COMMENT='Test Table comment'", sql) - self.assertIn("COMMENT 'This column acts as it\\'s own comment'", sql) - self.assertRegex(sql, r".*\\n.*") - self.assertRegex(sql, r".*it\\'s.*") - - async def test_schema_no_db_constraint(self): - self.maxDiff = None - await self.init_for("tests.schema.models_no_db_constraint") + assert "`team_id` INT NOT NULL" in sql + assert "FOREIGN KEY (`team_id`) REFERENCES `team` (`id`) ON DELETE CASCADE" in sql + finally: + await _teardown_tortoise() + + +@pytest.mark.asyncio +async def test_mysql_table_and_row_comment_generation(): + await _reset_tortoise() + try: + sqls = await _init_for_mysql("tests.testmodels") + sql = _get_sql(sqls, "comments") + assert "COMMENT='Test Table comment'" in sql + assert "COMMENT 'This column acts as it\\'s own comment'" in sql + assert re.search(r".*\\n.*", sql) + assert re.search(r".*it\\'s.*", sql) + finally: + await _teardown_tortoise() + + +@pytest.mark.asyncio +async def test_mysql_schema_no_db_constraint(): + await _reset_tortoise() + try: + await _init_for_mysql("tests.schema.models_no_db_constraint") sql = get_schema_sql(connections.get("default"), safe=False) - self.assertEqual( - sql.strip(), - r"""CREATE TABLE `team` ( + assert ( + sql.strip() + == r"""CREATE TABLE `team` ( `name` VARCHAR(50) NOT NULL PRIMARY KEY COMMENT 'The TEAM name (and PK)', `key` INT NOT NULL, `manager_id` VARCHAR(50), @@ -517,16 +700,21 @@ async def test_schema_no_db_constraint(self): `event_id` BIGINT NOT NULL, `team_id` VARCHAR(50) NOT NULL, UNIQUE KEY `uidx_teamevents_event_i_664dbc` (`event_id`, `team_id`) -) CHARACTER SET utf8mb4 COMMENT='How participants relate';""", +) CHARACTER SET utf8mb4 COMMENT='How participants relate';""" ) + finally: + await _teardown_tortoise() - async def test_schema(self): - self.maxDiff = None - await 
self.init_for("tests.schema.models_schema_create") + +@pytest.mark.asyncio +async def test_mysql_schema(): + await _reset_tortoise() + try: + await _init_for_mysql("tests.schema.models_schema_create") sql = get_schema_sql(connections.get("default"), safe=False) - self.assertEqual( - sql.strip(), - """ + assert ( + sql.strip() + == """ CREATE TABLE `company` ( `id` INT NOT NULL PRIMARY KEY AUTO_INCREMENT, `name` LONGTEXT NOT NULL, @@ -619,20 +807,25 @@ async def test_schema(self): FOREIGN KEY (`team_id`) REFERENCES `team` (`name`) ON DELETE SET NULL, UNIQUE KEY `uidx_teamevents_event_i_664dbc` (`event_id`, `team_id`) ) CHARACTER SET utf8mb4 COMMENT='How participants relate'; -""".strip(), +""".strip() ) + finally: + await _teardown_tortoise() - async def test_schema_safe(self): - self.maxDiff = None - await self.init_for("tests.schema.models_schema_create") + +@pytest.mark.asyncio +async def test_mysql_schema_safe(): + await _reset_tortoise() + try: + await _init_for_mysql("tests.schema.models_schema_create") sql = get_schema_sql(connections.get("default"), safe=True).strip() - if sql == self.safe_schema_sql: + if sql == SAFE_SCHEMA_SQL: # Sometimes github action get different result from local machine(Ubuntu20) - self.assertEqual(sql, self.safe_schema_sql) + assert sql == SAFE_SCHEMA_SQL return - self.assertEqual( - sql, - """ + assert ( + sql + == """ CREATE TABLE IF NOT EXISTS `company` ( `id` INT NOT NULL PRIMARY KEY AUTO_INCREMENT, `name` LONGTEXT NOT NULL, @@ -725,44 +918,61 @@ async def test_schema_safe(self): FOREIGN KEY (`team_id`) REFERENCES `team` (`name`) ON DELETE SET NULL, UNIQUE KEY `uidx_teamevents_event_i_664dbc` (`event_id`, `team_id`) ) CHARACTER SET utf8mb4 COMMENT='How participants relate'; -""".strip(), +""".strip() ) + finally: + await _teardown_tortoise() - async def test_index_safe(self): - await self.init_for("tests.schema.models_mysql_index") + +@pytest.mark.asyncio +async def test_mysql_index_safe(): + await _reset_tortoise() + try: + 
await _init_for_mysql("tests.schema.models_mysql_index") sql = get_schema_sql(connections.get("default"), safe=True) - self.assertEqual( - sql, - """CREATE TABLE IF NOT EXISTS `index` ( + assert ( + sql + == """CREATE TABLE IF NOT EXISTS `index` ( `id` INT NOT NULL PRIMARY KEY AUTO_INCREMENT, `full_text` LONGTEXT NOT NULL, `geometry` GEOMETRY NOT NULL, FULLTEXT KEY `idx_index_full_te_3caba4` (`full_text`) WITH PARSER ngram, SPATIAL KEY `idx_index_geometr_0b4dfb` (`geometry`) -) CHARACTER SET utf8mb4;""", +) CHARACTER SET utf8mb4;""" ) + finally: + await _teardown_tortoise() + - async def test_index_unsafe(self): - await self.init_for("tests.schema.models_mysql_index") +@pytest.mark.asyncio +async def test_mysql_index_unsafe(): + await _reset_tortoise() + try: + await _init_for_mysql("tests.schema.models_mysql_index") sql = get_schema_sql(connections.get("default"), safe=False) - self.assertEqual( - sql, - """CREATE TABLE `index` ( + assert ( + sql + == """CREATE TABLE `index` ( `id` INT NOT NULL PRIMARY KEY AUTO_INCREMENT, `full_text` LONGTEXT NOT NULL, `geometry` GEOMETRY NOT NULL, FULLTEXT KEY `idx_index_full_te_3caba4` (`full_text`) WITH PARSER ngram, SPATIAL KEY `idx_index_geometr_0b4dfb` (`geometry`) -) CHARACTER SET utf8mb4;""", +) CHARACTER SET utf8mb4;""" ) + finally: + await _teardown_tortoise() - async def test_m2m_no_auto_create(self): - self.maxDiff = None - await self.init_for("tests.schema.models_no_auto_create_m2m") + +@pytest.mark.asyncio +async def test_mysql_m2m_no_auto_create(): + await _reset_tortoise() + try: + await _init_for_mysql("tests.schema.models_no_auto_create_m2m") sql = get_schema_sql(connections.get("default"), safe=False) - self.assertEqual( - sql.strip(), - r"""CREATE TABLE `team` ( + assert ( + sql.strip() + == r"""CREATE TABLE `team` ( `name` VARCHAR(50) NOT NULL PRIMARY KEY COMMENT 'The TEAM name (and PK)', `key` INT NOT NULL, `manager_id` VARCHAR(50), @@ -804,40 +1014,80 @@ async def test_m2m_no_auto_create(self): FOREIGN KEY 
(`team_id`) REFERENCES `team` (`name`) ON DELETE CASCADE, UNIQUE KEY `uidx_team_team_team_re_d994df` (`team_rel_id`, `team_id`) ) CHARACTER SET utf8mb4; -""".strip(), +""".strip() ) + finally: + await _teardown_tortoise() -class GenerateSchemaPostgresSQL(TestGenerateSchema): - async def init_for(self, module: str, safe=False) -> None: - raise test.SkipTest("This class is abstract") +# ============================================================================ +# PostgreSQL Tests (asyncpg) +# ============================================================================ - async def test_noid(self): - await self.init_for("tests.testmodels") - sql = self.get_sql('"noid"') - self.assertIn('"name" VARCHAR(255)', sql) - self.assertIn('"id" SERIAL NOT NULL PRIMARY KEY', sql) - async def test_table_and_row_comment_generation(self): - await self.init_for("tests.testmodels") - sql = self.get_sql("comments") - self.assertIn("COMMENT ON TABLE \"comments\" IS 'Test Table comment'", sql) - self.assertIn( +async def _init_for_asyncpg(module: str, safe: bool = False) -> list[str]: + """Initialize Tortoise for asyncpg and return SQL statements.""" + try: + with patch("asyncpg.create_pool", new=MagicMock()): + await Tortoise.init( + { + "connections": { + "default": { + "engine": "tortoise.backends.asyncpg", + "credentials": { + "database": "test", + "host": "127.0.0.1", + "password": "foomip", + "port": 5432, + "user": "root", + }, + } + }, + "apps": {"models": {"models": [module], "default_connection": "default"}}, + } + ) + return get_schema_sql(connections.get("default"), safe).split("; ") + except ImportError: + pytest.skip("asyncpg not installed") + + +@pytest.mark.asyncio +async def test_asyncpg_noid(): + await _reset_tortoise() + try: + sqls = await _init_for_asyncpg("tests.testmodels") + sql = _get_sql(sqls, '"noid"') + assert '"name" VARCHAR(255)' in sql + assert '"id" SERIAL NOT NULL PRIMARY KEY' in sql + finally: + await _teardown_tortoise() + + +@pytest.mark.asyncio 
+async def test_asyncpg_table_and_row_comment_generation(): + await _reset_tortoise() + try: + sqls = await _init_for_asyncpg("tests.testmodels") + sql = _get_sql(sqls, "comments") + assert "COMMENT ON TABLE \"comments\" IS 'Test Table comment'" in sql + assert ( 'COMMENT ON COLUMN "comments"."escaped_comment_field" IS ' - "'This column acts as it''s own comment'", - sql, - ) - self.assertIn( - 'COMMENT ON COLUMN "comments"."multiline_comment" IS \'Some \\n comment\'', sql + "'This column acts as it''s own comment'" in sql ) + assert 'COMMENT ON COLUMN "comments"."multiline_comment" IS \'Some \\n comment\'' in sql + finally: + await _teardown_tortoise() + - async def test_schema_no_db_constraint(self): - self.maxDiff = None - await self.init_for("tests.schema.models_no_db_constraint") +@pytest.mark.asyncio +async def test_asyncpg_schema_no_db_constraint(): + await _reset_tortoise() + try: + await _init_for_asyncpg("tests.schema.models_no_db_constraint") sql = get_schema_sql(connections.get("default"), safe=False) - self.assertEqual( - sql.strip(), - r"""CREATE TABLE "team" ( + assert ( + sql.strip() + == r"""CREATE TABLE "team" ( "name" VARCHAR(50) NOT NULL PRIMARY KEY, "key" INT NOT NULL, "manager_id" VARCHAR(50) @@ -880,16 +1130,21 @@ async def test_schema_no_db_constraint(self): "team_id" VARCHAR(50) NOT NULL ); COMMENT ON TABLE "teamevents" IS 'How participants relate'; -CREATE UNIQUE INDEX "uidx_teamevents_event_i_664dbc" ON "teamevents" ("event_id", "team_id");""", +CREATE UNIQUE INDEX "uidx_teamevents_event_i_664dbc" ON "teamevents" ("event_id", "team_id");""" ) + finally: + await _teardown_tortoise() - async def test_schema(self): - self.maxDiff = None - await self.init_for("tests.schema.models_schema_create") + +@pytest.mark.asyncio +async def test_asyncpg_schema(): + await _reset_tortoise() + try: + await _init_for_asyncpg("tests.schema.models_schema_create") sql = get_schema_sql(connections.get("default"), safe=False) - self.assertEqual( - sql.strip(), - 
""" + assert ( + sql.strip() + == """ CREATE TABLE "company" ( "id" SERIAL NOT NULL PRIMARY KEY, "name" TEXT NOT NULL, @@ -985,16 +1240,21 @@ async def test_schema(self): ); COMMENT ON TABLE "teamevents" IS 'How participants relate'; CREATE UNIQUE INDEX "uidx_teamevents_event_i_664dbc" ON "teamevents" ("event_id", "team_id"); -""".strip(), +""".strip() ) + finally: + await _teardown_tortoise() - async def test_schema_safe(self): - self.maxDiff = None - await self.init_for("tests.schema.models_schema_create") + +@pytest.mark.asyncio +async def test_asyncpg_schema_safe(): + await _reset_tortoise() + try: + await _init_for_asyncpg("tests.schema.models_schema_create") sql = get_schema_sql(connections.get("default"), safe=True) - self.assertEqual( - sql.strip(), - """ + assert ( + sql.strip() + == """ CREATE TABLE IF NOT EXISTS "company" ( "id" SERIAL NOT NULL PRIMARY KEY, "name" TEXT NOT NULL, @@ -1090,15 +1350,21 @@ async def test_schema_safe(self): ); COMMENT ON TABLE "teamevents" IS 'How participants relate'; CREATE UNIQUE INDEX IF NOT EXISTS "uidx_teamevents_event_i_664dbc" ON "teamevents" ("event_id", "team_id"); -""".strip(), +""".strip() ) + finally: + await _teardown_tortoise() - async def test_index_unsafe(self): - await self.init_for("tests.schema.models_postgres_index") + +@pytest.mark.asyncio +async def test_asyncpg_index_unsafe(): + await _reset_tortoise() + try: + await _init_for_asyncpg("tests.schema.models_postgres_index") sql = get_schema_sql(connections.get("default"), safe=False) - self.assertEqual( - sql, - """CREATE TABLE "index" ( + assert ( + sql + == """CREATE TABLE "index" ( "id" SERIAL NOT NULL PRIMARY KEY, "bloom" VARCHAR(200) NOT NULL, "brin" VARCHAR(200) NOT NULL, @@ -1117,15 +1383,21 @@ async def test_index_unsafe(self): CREATE INDEX "idx_index_sp_gist_2c0bad" ON "index" USING SPGIST ("sp_gist"); CREATE INDEX "idx_index_hash_cfe6b5" ON "index" USING HASH ("hash"); CREATE INDEX "idx_index_partial_c5be6a" ON "index" ("partial") WHERE id = 1; 
-CREATE INDEX "idx_index_(TO_TSV_50a2c7" ON "index" USING GIN ((TO_TSVECTOR('english',(("title" || ' ') || "body"))));""", +CREATE INDEX "idx_index_(TO_TSV_50a2c7" ON "index" USING GIN ((TO_TSVECTOR('english',(("title" || ' ') || "body"))));""" ) + finally: + await _teardown_tortoise() + - async def test_index_safe(self): - await self.init_for("tests.schema.models_postgres_index") +@pytest.mark.asyncio +async def test_asyncpg_index_safe(): + await _reset_tortoise() + try: + await _init_for_asyncpg("tests.schema.models_postgres_index") sql = get_schema_sql(connections.get("default"), safe=True) - self.assertEqual( - sql, - """CREATE TABLE IF NOT EXISTS "index" ( + assert ( + sql + == """CREATE TABLE IF NOT EXISTS "index" ( "id" SERIAL NOT NULL PRIMARY KEY, "bloom" VARCHAR(200) NOT NULL, "brin" VARCHAR(200) NOT NULL, @@ -1144,16 +1416,21 @@ async def test_index_safe(self): CREATE INDEX IF NOT EXISTS "idx_index_sp_gist_2c0bad" ON "index" USING SPGIST ("sp_gist"); CREATE INDEX IF NOT EXISTS "idx_index_hash_cfe6b5" ON "index" USING HASH ("hash"); CREATE INDEX IF NOT EXISTS "idx_index_partial_c5be6a" ON "index" ("partial") WHERE id = 1; -CREATE INDEX IF NOT EXISTS "idx_index_(TO_TSV_50a2c7" ON "index" USING GIN ((TO_TSVECTOR('english',(("title" || ' ') || "body"))));""", +CREATE INDEX IF NOT EXISTS "idx_index_(TO_TSV_50a2c7" ON "index" USING GIN ((TO_TSVECTOR('english',(("title" || ' ') || "body"))));""" ) + finally: + await _teardown_tortoise() + - async def test_m2m_no_auto_create(self): - self.maxDiff = None - await self.init_for("tests.schema.models_no_auto_create_m2m") +@pytest.mark.asyncio +async def test_asyncpg_m2m_no_auto_create(): + await _reset_tortoise() + try: + await _init_for_asyncpg("tests.schema.models_no_auto_create_m2m") sql = get_schema_sql(connections.get("default"), safe=False) - self.assertEqual( - sql.strip(), - r"""CREATE TABLE "team" ( + assert ( + sql.strip() + == r"""CREATE TABLE "team" ( "name" VARCHAR(50) NOT NULL PRIMARY KEY, "key" INT NOT 
NULL, "manager_id" VARCHAR(50) REFERENCES "team" ("name") ON DELETE CASCADE @@ -1199,15 +1476,21 @@ async def test_m2m_no_auto_create(self): "team_id" VARCHAR(50) NOT NULL REFERENCES "team" ("name") ON DELETE CASCADE ); CREATE UNIQUE INDEX "uidx_team_team_team_re_d994df" ON "team_team" ("team_rel_id", "team_id"); -""".strip(), +""".strip() ) + finally: + await _teardown_tortoise() + - async def test_pgfields_unsafe(self): - await self.init_for("tests.schema.models_postgres_fields") +@pytest.mark.asyncio +async def test_asyncpg_pgfields_unsafe(): + await _reset_tortoise() + try: + await _init_for_asyncpg("tests.schema.models_postgres_fields") sql = get_schema_sql(connections.get("default"), safe=False) - self.assertEqual( - sql, - """CREATE TABLE "postgres_fields" ( + assert ( + sql + == """CREATE TABLE "postgres_fields" ( "id" SERIAL NOT NULL PRIMARY KEY, "tsvector" TSVECTOR NOT NULL, "text_array" TEXT[] NOT NULL DEFAULT '{"a","b","c"}', @@ -1215,15 +1498,21 @@ async def test_pgfields_unsafe(self): "int_array" INT[] DEFAULT '{1,2,3}', "real_array" REAL[] NOT NULL DEFAULT '{1.1,2.2,3.3}' ); -COMMENT ON COLUMN "postgres_fields"."real_array" IS 'this is array of real numbers';""", +COMMENT ON COLUMN "postgres_fields"."real_array" IS 'this is array of real numbers';""" ) + finally: + await _teardown_tortoise() - async def test_pgfields_safe(self): - await self.init_for("tests.schema.models_postgres_fields") + +@pytest.mark.asyncio +async def test_asyncpg_pgfields_safe(): + await _reset_tortoise() + try: + await _init_for_asyncpg("tests.schema.models_postgres_fields") sql = get_schema_sql(connections.get("default"), safe=True) - self.assertEqual( - sql, - """CREATE TABLE IF NOT EXISTS "postgres_fields" ( + assert ( + sql + == """CREATE TABLE IF NOT EXISTS "postgres_fields" ( "id" SERIAL NOT NULL PRIMARY KEY, "tsvector" TSVECTOR NOT NULL, "text_array" TEXT[] NOT NULL DEFAULT '{"a","b","c"}', @@ -1231,57 +1520,513 @@ async def test_pgfields_safe(self): "int_array" INT[] 
DEFAULT '{1,2,3}', "real_array" REAL[] NOT NULL DEFAULT '{1.1,2.2,3.3}' ); -COMMENT ON COLUMN "postgres_fields"."real_array" IS 'this is array of real numbers';""", +COMMENT ON COLUMN "postgres_fields"."real_array" IS 'this is array of real numbers';""" ) + finally: + await _teardown_tortoise() -class TestGenerateSchemaAsyncpg(GenerateSchemaPostgresSQL): - async def init_for(self, module: str, safe=False) -> None: - try: - with patch("asyncpg.create_pool", new=MagicMock()): - await Tortoise.init( - { - "connections": { - "default": { - "engine": "tortoise.backends.asyncpg", - "credentials": { - "database": "test", - "host": "127.0.0.1", - "password": "foomip", - "port": 5432, - "user": "root", - }, - } - }, - "apps": {"models": {"models": [module], "default_connection": "default"}}, - } - ) - self.sqls = get_schema_sql(connections.get("default"), safe).split("; ") - except ImportError: - raise test.SkipTest("asyncpg not installed") - - -class TestGenerateSchemaPsycopg(GenerateSchemaPostgresSQL): - async def init_for(self, module: str, safe=False) -> None: - try: - with patch("psycopg_pool.AsyncConnectionPool.open", new=MagicMock()): - await Tortoise.init( - { - "connections": { - "default": { - "engine": "tortoise.backends.psycopg", - "credentials": { - "database": "test", - "host": "127.0.0.1", - "password": "foomip", - "port": 5432, - "user": "root", - }, - } - }, - "apps": {"models": {"models": [module], "default_connection": "default"}}, - } - ) - self.sqls = get_schema_sql(connections.get("default"), safe).split("; ") - except ImportError: - raise test.SkipTest("psycopg not installed") +# ============================================================================ +# PostgreSQL Tests (psycopg) +# ============================================================================ + + +async def _init_for_psycopg(module: str, safe: bool = False) -> list[str]: + """Initialize Tortoise for psycopg and return SQL statements.""" + try: + with 
patch("psycopg_pool.AsyncConnectionPool.open", new=MagicMock()): + await Tortoise.init( + { + "connections": { + "default": { + "engine": "tortoise.backends.psycopg", + "credentials": { + "database": "test", + "host": "127.0.0.1", + "password": "foomip", + "port": 5432, + "user": "root", + }, + } + }, + "apps": {"models": {"models": [module], "default_connection": "default"}}, + } + ) + return get_schema_sql(connections.get("default"), safe).split("; ") + except ImportError: + pytest.skip("psycopg not installed") + + +@pytest.mark.asyncio +async def test_psycopg_noid(): + await _reset_tortoise() + try: + sqls = await _init_for_psycopg("tests.testmodels") + sql = _get_sql(sqls, '"noid"') + assert '"name" VARCHAR(255)' in sql + assert '"id" SERIAL NOT NULL PRIMARY KEY' in sql + finally: + await _teardown_tortoise() + + +@pytest.mark.asyncio +async def test_psycopg_table_and_row_comment_generation(): + await _reset_tortoise() + try: + sqls = await _init_for_psycopg("tests.testmodels") + sql = _get_sql(sqls, "comments") + assert "COMMENT ON TABLE \"comments\" IS 'Test Table comment'" in sql + assert ( + 'COMMENT ON COLUMN "comments"."escaped_comment_field" IS ' + "'This column acts as it''s own comment'" in sql + ) + assert 'COMMENT ON COLUMN "comments"."multiline_comment" IS \'Some \\n comment\'' in sql + finally: + await _teardown_tortoise() + + +@pytest.mark.asyncio +async def test_psycopg_schema_no_db_constraint(): + await _reset_tortoise() + try: + await _init_for_psycopg("tests.schema.models_no_db_constraint") + sql = get_schema_sql(connections.get("default"), safe=False) + assert ( + sql.strip() + == r"""CREATE TABLE "team" ( + "name" VARCHAR(50) NOT NULL PRIMARY KEY, + "key" INT NOT NULL, + "manager_id" VARCHAR(50) +); +CREATE INDEX "idx_team_manager_676134" ON "team" ("manager_id", "key"); +CREATE INDEX "idx_team_manager_ef8f69" ON "team" ("manager_id", "name"); +COMMENT ON COLUMN "team"."name" IS 'The TEAM name (and PK)'; +COMMENT ON TABLE "team" IS 'The 
TEAMS!'; +CREATE TABLE "tournament" ( + "tid" SMALLSERIAL NOT NULL PRIMARY KEY, + "name" VARCHAR(100) NOT NULL, + "created" TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP +); +CREATE INDEX "idx_tournament_name_6fe200" ON "tournament" ("name"); +COMMENT ON COLUMN "tournament"."name" IS 'Tournament name'; +COMMENT ON COLUMN "tournament"."created" IS 'Created */''`/* datetime'; +COMMENT ON TABLE "tournament" IS 'What Tournaments */''`/* we have'; +CREATE TABLE "event" ( + "id" BIGSERIAL NOT NULL PRIMARY KEY, + "name" TEXT NOT NULL, + "modified" TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + "prize" DECIMAL(10,2), + "token" VARCHAR(100) NOT NULL UNIQUE, + "key" VARCHAR(100) NOT NULL, + "tournament_id" SMALLINT NOT NULL, + CONSTRAINT "uid_event_name_c6f89f" UNIQUE ("name", "prize"), + CONSTRAINT "uid_event_tournam_a5b730" UNIQUE ("tournament_id", "key") +); +COMMENT ON COLUMN "event"."id" IS 'Event ID'; +COMMENT ON COLUMN "event"."token" IS 'Unique token'; +COMMENT ON COLUMN "event"."tournament_id" IS 'FK to tournament'; +COMMENT ON TABLE "event" IS 'This table contains a list of all the events'; +CREATE TABLE "team_team" ( + "team_rel_id" VARCHAR(50) NOT NULL, + "team_id" VARCHAR(50) NOT NULL +); +CREATE UNIQUE INDEX "uidx_team_team_team_re_d994df" ON "team_team" ("team_rel_id", "team_id"); +CREATE TABLE "teamevents" ( + "event_id" BIGINT NOT NULL, + "team_id" VARCHAR(50) NOT NULL +); +COMMENT ON TABLE "teamevents" IS 'How participants relate'; +CREATE UNIQUE INDEX "uidx_teamevents_event_i_664dbc" ON "teamevents" ("event_id", "team_id");""" + ) + finally: + await _teardown_tortoise() + + +@pytest.mark.asyncio +async def test_psycopg_schema(): + await _reset_tortoise() + try: + await _init_for_psycopg("tests.schema.models_schema_create") + sql = get_schema_sql(connections.get("default"), safe=False) + assert ( + sql.strip() + == """ +CREATE TABLE "company" ( + "id" SERIAL NOT NULL PRIMARY KEY, + "name" TEXT NOT NULL, + "uuid" UUID NOT NULL UNIQUE +); +CREATE TABLE 
"defaultpk" ( + "id" SERIAL NOT NULL PRIMARY KEY, + "val" INT NOT NULL +); +CREATE TABLE "employee" ( + "id" SERIAL NOT NULL PRIMARY KEY, + "name" TEXT NOT NULL, + "company_id" UUID NOT NULL REFERENCES "company" ("uuid") ON DELETE CASCADE +); +CREATE TABLE "inheritedmodel" ( + "id" SERIAL NOT NULL PRIMARY KEY, + "zero" INT NOT NULL, + "one" VARCHAR(40), + "new_field" VARCHAR(100) NOT NULL, + "two" VARCHAR(40) NOT NULL, + "name" TEXT NOT NULL +); +CREATE TABLE "sometable" ( + "sometable_id" SERIAL NOT NULL PRIMARY KEY, + "some_chars_table" VARCHAR(255) NOT NULL, + "fk_sometable" INT REFERENCES "sometable" ("sometable_id") ON DELETE CASCADE +); +CREATE INDEX "idx_sometable_some_ch_3d69eb" ON "sometable" ("some_chars_table"); +CREATE TABLE "team" ( + "name" VARCHAR(50) NOT NULL PRIMARY KEY, + "key" INT NOT NULL, + "manager_id" VARCHAR(50) REFERENCES "team" ("name") ON DELETE CASCADE +); +CREATE INDEX "idx_team_manager_676134" ON "team" ("manager_id", "key"); +CREATE INDEX "idx_team_manager_ef8f69" ON "team" ("manager_id", "name"); +COMMENT ON COLUMN "team"."name" IS 'The TEAM name (and PK)'; +COMMENT ON TABLE "team" IS 'The TEAMS!'; +CREATE TABLE "teamaddress" ( + "city" VARCHAR(50) NOT NULL, + "country" VARCHAR(50) NOT NULL, + "street" VARCHAR(128) NOT NULL, + "team_id" VARCHAR(50) NOT NULL PRIMARY KEY REFERENCES "team" ("name") ON DELETE CASCADE +); +COMMENT ON COLUMN "teamaddress"."city" IS 'City'; +COMMENT ON COLUMN "teamaddress"."country" IS 'Country'; +COMMENT ON COLUMN "teamaddress"."street" IS 'Street Address'; +COMMENT ON TABLE "teamaddress" IS 'The Team''s address'; +CREATE TABLE "tournament" ( + "tid" SMALLSERIAL NOT NULL PRIMARY KEY, + "name" VARCHAR(100) NOT NULL, + "created" TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP +); +CREATE INDEX "idx_tournament_name_6fe200" ON "tournament" ("name"); +COMMENT ON COLUMN "tournament"."name" IS 'Tournament name'; +COMMENT ON COLUMN "tournament"."created" IS 'Created */''`/* datetime'; +COMMENT ON TABLE "tournament" 
IS 'What Tournaments */''`/* we have'; +CREATE TABLE "event" ( + "id" BIGSERIAL NOT NULL PRIMARY KEY, + "name" TEXT NOT NULL, + "modified" TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + "prize" DECIMAL(10,2), + "token" VARCHAR(100) NOT NULL UNIQUE, + "key" VARCHAR(100) NOT NULL, + "tournament_id" SMALLINT NOT NULL REFERENCES "tournament" ("tid") ON DELETE CASCADE, + CONSTRAINT "uid_event_name_c6f89f" UNIQUE ("name", "prize"), + CONSTRAINT "uid_event_tournam_a5b730" UNIQUE ("tournament_id", "key") +); +COMMENT ON COLUMN "event"."id" IS 'Event ID'; +COMMENT ON COLUMN "event"."token" IS 'Unique token'; +COMMENT ON COLUMN "event"."tournament_id" IS 'FK to tournament'; +COMMENT ON TABLE "event" IS 'This table contains a list of all the events'; +CREATE TABLE "venueinformation" ( + "id" SERIAL NOT NULL PRIMARY KEY, + "name" VARCHAR(128) NOT NULL, + "capacity" INT NOT NULL, + "rent" DOUBLE PRECISION NOT NULL, + "team_id" VARCHAR(50) UNIQUE REFERENCES "team" ("name") ON DELETE SET NULL +); +COMMENT ON COLUMN "venueinformation"."capacity" IS 'No. 
of seats'; +CREATE TABLE "sometable_self" ( + "backward_sts" INT NOT NULL REFERENCES "sometable" ("sometable_id") ON DELETE CASCADE, + "sts_forward" INT NOT NULL REFERENCES "sometable" ("sometable_id") ON DELETE CASCADE +); +CREATE UNIQUE INDEX "uidx_sometable_s_backwar_fc8fc8" ON "sometable_self" ("backward_sts", "sts_forward"); +CREATE TABLE "team_team" ( + "team_rel_id" VARCHAR(50) NOT NULL REFERENCES "team" ("name") ON DELETE CASCADE, + "team_id" VARCHAR(50) NOT NULL REFERENCES "team" ("name") ON DELETE CASCADE +); +CREATE UNIQUE INDEX "uidx_team_team_team_re_d994df" ON "team_team" ("team_rel_id", "team_id"); +CREATE TABLE "teamevents" ( + "event_id" BIGINT NOT NULL REFERENCES "event" ("id") ON DELETE SET NULL, + "team_id" VARCHAR(50) NOT NULL REFERENCES "team" ("name") ON DELETE SET NULL +); +COMMENT ON TABLE "teamevents" IS 'How participants relate'; +CREATE UNIQUE INDEX "uidx_teamevents_event_i_664dbc" ON "teamevents" ("event_id", "team_id"); +""".strip() + ) + finally: + await _teardown_tortoise() + + +@pytest.mark.asyncio +async def test_psycopg_schema_safe(): + await _reset_tortoise() + try: + await _init_for_psycopg("tests.schema.models_schema_create") + sql = get_schema_sql(connections.get("default"), safe=True) + assert ( + sql.strip() + == """ +CREATE TABLE IF NOT EXISTS "company" ( + "id" SERIAL NOT NULL PRIMARY KEY, + "name" TEXT NOT NULL, + "uuid" UUID NOT NULL UNIQUE +); +CREATE TABLE IF NOT EXISTS "defaultpk" ( + "id" SERIAL NOT NULL PRIMARY KEY, + "val" INT NOT NULL +); +CREATE TABLE IF NOT EXISTS "employee" ( + "id" SERIAL NOT NULL PRIMARY KEY, + "name" TEXT NOT NULL, + "company_id" UUID NOT NULL REFERENCES "company" ("uuid") ON DELETE CASCADE +); +CREATE TABLE IF NOT EXISTS "inheritedmodel" ( + "id" SERIAL NOT NULL PRIMARY KEY, + "zero" INT NOT NULL, + "one" VARCHAR(40), + "new_field" VARCHAR(100) NOT NULL, + "two" VARCHAR(40) NOT NULL, + "name" TEXT NOT NULL +); +CREATE TABLE IF NOT EXISTS "sometable" ( + "sometable_id" SERIAL NOT NULL 
PRIMARY KEY, + "some_chars_table" VARCHAR(255) NOT NULL, + "fk_sometable" INT REFERENCES "sometable" ("sometable_id") ON DELETE CASCADE +); +CREATE INDEX IF NOT EXISTS "idx_sometable_some_ch_3d69eb" ON "sometable" ("some_chars_table"); +CREATE TABLE IF NOT EXISTS "team" ( + "name" VARCHAR(50) NOT NULL PRIMARY KEY, + "key" INT NOT NULL, + "manager_id" VARCHAR(50) REFERENCES "team" ("name") ON DELETE CASCADE +); +CREATE INDEX IF NOT EXISTS "idx_team_manager_676134" ON "team" ("manager_id", "key"); +CREATE INDEX IF NOT EXISTS "idx_team_manager_ef8f69" ON "team" ("manager_id", "name"); +COMMENT ON COLUMN "team"."name" IS 'The TEAM name (and PK)'; +COMMENT ON TABLE "team" IS 'The TEAMS!'; +CREATE TABLE IF NOT EXISTS "teamaddress" ( + "city" VARCHAR(50) NOT NULL, + "country" VARCHAR(50) NOT NULL, + "street" VARCHAR(128) NOT NULL, + "team_id" VARCHAR(50) NOT NULL PRIMARY KEY REFERENCES "team" ("name") ON DELETE CASCADE +); +COMMENT ON COLUMN "teamaddress"."city" IS 'City'; +COMMENT ON COLUMN "teamaddress"."country" IS 'Country'; +COMMENT ON COLUMN "teamaddress"."street" IS 'Street Address'; +COMMENT ON TABLE "teamaddress" IS 'The Team''s address'; +CREATE TABLE IF NOT EXISTS "tournament" ( + "tid" SMALLSERIAL NOT NULL PRIMARY KEY, + "name" VARCHAR(100) NOT NULL, + "created" TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP +); +CREATE INDEX IF NOT EXISTS "idx_tournament_name_6fe200" ON "tournament" ("name"); +COMMENT ON COLUMN "tournament"."name" IS 'Tournament name'; +COMMENT ON COLUMN "tournament"."created" IS 'Created */''`/* datetime'; +COMMENT ON TABLE "tournament" IS 'What Tournaments */''`/* we have'; +CREATE TABLE IF NOT EXISTS "event" ( + "id" BIGSERIAL NOT NULL PRIMARY KEY, + "name" TEXT NOT NULL, + "modified" TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + "prize" DECIMAL(10,2), + "token" VARCHAR(100) NOT NULL UNIQUE, + "key" VARCHAR(100) NOT NULL, + "tournament_id" SMALLINT NOT NULL REFERENCES "tournament" ("tid") ON DELETE CASCADE, + CONSTRAINT 
"uid_event_name_c6f89f" UNIQUE ("name", "prize"), + CONSTRAINT "uid_event_tournam_a5b730" UNIQUE ("tournament_id", "key") +); +COMMENT ON COLUMN "event"."id" IS 'Event ID'; +COMMENT ON COLUMN "event"."token" IS 'Unique token'; +COMMENT ON COLUMN "event"."tournament_id" IS 'FK to tournament'; +COMMENT ON TABLE "event" IS 'This table contains a list of all the events'; +CREATE TABLE IF NOT EXISTS "venueinformation" ( + "id" SERIAL NOT NULL PRIMARY KEY, + "name" VARCHAR(128) NOT NULL, + "capacity" INT NOT NULL, + "rent" DOUBLE PRECISION NOT NULL, + "team_id" VARCHAR(50) UNIQUE REFERENCES "team" ("name") ON DELETE SET NULL +); +COMMENT ON COLUMN "venueinformation"."capacity" IS 'No. of seats'; +CREATE TABLE IF NOT EXISTS "sometable_self" ( + "backward_sts" INT NOT NULL REFERENCES "sometable" ("sometable_id") ON DELETE CASCADE, + "sts_forward" INT NOT NULL REFERENCES "sometable" ("sometable_id") ON DELETE CASCADE +); +CREATE UNIQUE INDEX IF NOT EXISTS "uidx_sometable_s_backwar_fc8fc8" ON "sometable_self" ("backward_sts", "sts_forward"); +CREATE TABLE IF NOT EXISTS "team_team" ( + "team_rel_id" VARCHAR(50) NOT NULL REFERENCES "team" ("name") ON DELETE CASCADE, + "team_id" VARCHAR(50) NOT NULL REFERENCES "team" ("name") ON DELETE CASCADE +); +CREATE UNIQUE INDEX IF NOT EXISTS "uidx_team_team_team_re_d994df" ON "team_team" ("team_rel_id", "team_id"); +CREATE TABLE IF NOT EXISTS "teamevents" ( + "event_id" BIGINT NOT NULL REFERENCES "event" ("id") ON DELETE SET NULL, + "team_id" VARCHAR(50) NOT NULL REFERENCES "team" ("name") ON DELETE SET NULL +); +COMMENT ON TABLE "teamevents" IS 'How participants relate'; +CREATE UNIQUE INDEX IF NOT EXISTS "uidx_teamevents_event_i_664dbc" ON "teamevents" ("event_id", "team_id"); +""".strip() + ) + finally: + await _teardown_tortoise() + + +@pytest.mark.asyncio +async def test_psycopg_index_unsafe(): + await _reset_tortoise() + try: + await _init_for_psycopg("tests.schema.models_postgres_index") + sql = 
get_schema_sql(connections.get("default"), safe=False) + assert ( + sql + == """CREATE TABLE "index" ( + "id" SERIAL NOT NULL PRIMARY KEY, + "bloom" VARCHAR(200) NOT NULL, + "brin" VARCHAR(200) NOT NULL, + "gin" TSVECTOR NOT NULL, + "gist" TSVECTOR NOT NULL, + "sp_gist" VARCHAR(200) NOT NULL, + "hash" VARCHAR(200) NOT NULL, + "partial" VARCHAR(200) NOT NULL, + "title" TEXT NOT NULL, + "body" TEXT NOT NULL +); +CREATE INDEX "idx_index_bloom_280137" ON "index" USING BLOOM ("bloom"); +CREATE INDEX "idx_index_brin_a54a00" ON "index" USING BRIN ("brin"); +CREATE INDEX "idx_index_gin_a403ee" ON "index" USING GIN ("gin"); +CREATE INDEX "idx_index_gist_c807bf" ON "index" USING GIST ("gist"); +CREATE INDEX "idx_index_sp_gist_2c0bad" ON "index" USING SPGIST ("sp_gist"); +CREATE INDEX "idx_index_hash_cfe6b5" ON "index" USING HASH ("hash"); +CREATE INDEX "idx_index_partial_c5be6a" ON "index" ("partial") WHERE id = 1; +CREATE INDEX "idx_index_(TO_TSV_50a2c7" ON "index" USING GIN ((TO_TSVECTOR('english',(("title" || ' ') || "body"))));""" + ) + finally: + await _teardown_tortoise() + + +@pytest.mark.asyncio +async def test_psycopg_index_safe(): + await _reset_tortoise() + try: + await _init_for_psycopg("tests.schema.models_postgres_index") + sql = get_schema_sql(connections.get("default"), safe=True) + assert ( + sql + == """CREATE TABLE IF NOT EXISTS "index" ( + "id" SERIAL NOT NULL PRIMARY KEY, + "bloom" VARCHAR(200) NOT NULL, + "brin" VARCHAR(200) NOT NULL, + "gin" TSVECTOR NOT NULL, + "gist" TSVECTOR NOT NULL, + "sp_gist" VARCHAR(200) NOT NULL, + "hash" VARCHAR(200) NOT NULL, + "partial" VARCHAR(200) NOT NULL, + "title" TEXT NOT NULL, + "body" TEXT NOT NULL +); +CREATE INDEX IF NOT EXISTS "idx_index_bloom_280137" ON "index" USING BLOOM ("bloom"); +CREATE INDEX IF NOT EXISTS "idx_index_brin_a54a00" ON "index" USING BRIN ("brin"); +CREATE INDEX IF NOT EXISTS "idx_index_gin_a403ee" ON "index" USING GIN ("gin"); +CREATE INDEX IF NOT EXISTS "idx_index_gist_c807bf" ON "index" 
USING GIST ("gist"); +CREATE INDEX IF NOT EXISTS "idx_index_sp_gist_2c0bad" ON "index" USING SPGIST ("sp_gist"); +CREATE INDEX IF NOT EXISTS "idx_index_hash_cfe6b5" ON "index" USING HASH ("hash"); +CREATE INDEX IF NOT EXISTS "idx_index_partial_c5be6a" ON "index" ("partial") WHERE id = 1; +CREATE INDEX IF NOT EXISTS "idx_index_(TO_TSV_50a2c7" ON "index" USING GIN ((TO_TSVECTOR('english',(("title" || ' ') || "body"))));""" + ) + finally: + await _teardown_tortoise() + + +@pytest.mark.asyncio +async def test_psycopg_m2m_no_auto_create(): + await _reset_tortoise() + try: + await _init_for_psycopg("tests.schema.models_no_auto_create_m2m") + sql = get_schema_sql(connections.get("default"), safe=False) + assert ( + sql.strip() + == r"""CREATE TABLE "team" ( + "name" VARCHAR(50) NOT NULL PRIMARY KEY, + "key" INT NOT NULL, + "manager_id" VARCHAR(50) REFERENCES "team" ("name") ON DELETE CASCADE +); +CREATE INDEX "idx_team_manager_676134" ON "team" ("manager_id", "key"); +CREATE INDEX "idx_team_manager_ef8f69" ON "team" ("manager_id", "name"); +COMMENT ON COLUMN "team"."name" IS 'The TEAM name (and PK)'; +COMMENT ON TABLE "team" IS 'The TEAMS!'; +CREATE TABLE "tournament" ( + "tid" SMALLSERIAL NOT NULL PRIMARY KEY, + "name" VARCHAR(100) NOT NULL, + "created" TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP +); +CREATE INDEX "idx_tournament_name_6fe200" ON "tournament" ("name"); +COMMENT ON COLUMN "tournament"."name" IS 'Tournament name'; +COMMENT ON COLUMN "tournament"."created" IS 'Created */''`/* datetime'; +COMMENT ON TABLE "tournament" IS 'What Tournaments */''`/* we have'; +CREATE TABLE "event" ( + "id" BIGSERIAL NOT NULL PRIMARY KEY, + "name" TEXT NOT NULL, + "modified" TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + "prize" DECIMAL(10,2), + "token" VARCHAR(100) NOT NULL UNIQUE, + "key" VARCHAR(100) NOT NULL, + "tournament_id" SMALLINT NOT NULL REFERENCES "tournament" ("tid") ON DELETE CASCADE, + CONSTRAINT "uid_event_name_c6f89f" UNIQUE ("name", "prize"), + CONSTRAINT 
"uid_event_tournam_a5b730" UNIQUE ("tournament_id", "key") +); +COMMENT ON COLUMN "event"."id" IS 'Event ID'; +COMMENT ON COLUMN "event"."token" IS 'Unique token'; +COMMENT ON COLUMN "event"."tournament_id" IS 'FK to tournament'; +COMMENT ON TABLE "event" IS 'This table contains a list of all the events'; +CREATE TABLE "teamevents" ( + "id" SERIAL NOT NULL PRIMARY KEY, + "score" INT NOT NULL, + "event_id" BIGINT NOT NULL REFERENCES "event" ("id") ON DELETE CASCADE, + "team_id" VARCHAR(50) NOT NULL REFERENCES "team" ("name") ON DELETE CASCADE, + CONSTRAINT "uid_teamevents_team_id_9e89fc" UNIQUE ("team_id", "event_id") +); +COMMENT ON TABLE "teamevents" IS 'How participants relate'; +CREATE TABLE "team_team" ( + "team_rel_id" VARCHAR(50) NOT NULL REFERENCES "team" ("name") ON DELETE CASCADE, + "team_id" VARCHAR(50) NOT NULL REFERENCES "team" ("name") ON DELETE CASCADE +); +CREATE UNIQUE INDEX "uidx_team_team_team_re_d994df" ON "team_team" ("team_rel_id", "team_id"); +""".strip() + ) + finally: + await _teardown_tortoise() + + +@pytest.mark.asyncio +async def test_psycopg_pgfields_unsafe(): + await _reset_tortoise() + try: + await _init_for_psycopg("tests.schema.models_postgres_fields") + sql = get_schema_sql(connections.get("default"), safe=False) + assert ( + sql + == """CREATE TABLE "postgres_fields" ( + "id" SERIAL NOT NULL PRIMARY KEY, + "tsvector" TSVECTOR NOT NULL, + "text_array" TEXT[] NOT NULL DEFAULT '{"a","b","c"}', + "varchar_array" VARCHAR(32)[] NOT NULL DEFAULT '{"aa","bbb","cccc"}', + "int_array" INT[] DEFAULT '{1,2,3}', + "real_array" REAL[] NOT NULL DEFAULT '{1.1,2.2,3.3}' +); +COMMENT ON COLUMN "postgres_fields"."real_array" IS 'this is array of real numbers';""" + ) + finally: + await _teardown_tortoise() + + +@pytest.mark.asyncio +async def test_psycopg_pgfields_safe(): + await _reset_tortoise() + try: + await _init_for_psycopg("tests.schema.models_postgres_fields") + sql = get_schema_sql(connections.get("default"), safe=True) + assert ( + sql + == 
"""CREATE TABLE IF NOT EXISTS "postgres_fields" ( + "id" SERIAL NOT NULL PRIMARY KEY, + "tsvector" TSVECTOR NOT NULL, + "text_array" TEXT[] NOT NULL DEFAULT '{"a","b","c"}', + "varchar_array" VARCHAR(32)[] NOT NULL DEFAULT '{"aa","bbb","cccc"}', + "int_array" INT[] DEFAULT '{1,2,3}', + "real_array" REAL[] NOT NULL DEFAULT '{1.1,2.2,3.3}' +); +COMMENT ON COLUMN "postgres_fields"."real_array" IS 'this is array of real numbers';""" + ) + finally: + await _teardown_tortoise() diff --git a/tests/test_aggregation.py b/tests/test_aggregation.py index 6297fba7e..fe608c1d2 100644 --- a/tests/test_aggregation.py +++ b/tests/test_aggregation.py @@ -1,5 +1,7 @@ from decimal import Decimal +import pytest + from tests.testmodels import ( Author, Book, @@ -16,298 +18,325 @@ from tortoise.functions import Avg, Coalesce, Concat, Count, Lower, Max, Min, Sum, Trim -class TestAggregation(test.TestCase): - async def test_aggregation(self): - tournament = Tournament(name="New Tournament") - await tournament.save() - await Tournament.create(name="Second tournament") - await Event(name="Without participants", tournament_id=tournament.id).save() - event = Event(name="Test", tournament_id=tournament.id) - await event.save() - participants = [] - for i in range(2): - team = Team(name=f"Team {(i + 1)}") - await team.save() - participants.append(team) - await event.participants.add(participants[0], participants[1]) - await event.participants.add(participants[0], participants[1]) - - tournaments_with_count = ( - await Tournament.all() - .annotate(events_count=Count("events")) - .filter(events_count__gte=1) - ) - self.assertEqual(len(tournaments_with_count), 1) - self.assertEqual(tournaments_with_count[0].events_count, 2) - - event_with_lowest_team_id = ( - await Event.filter(event_id=event.event_id) - .first() - .annotate(lowest_team_id=Min("participants__id")) - ) - self.assertEqual(event_with_lowest_team_id.lowest_team_id, participants[0].id) - - ordered_tournaments = ( - await 
Tournament.all().annotate(events_count=Count("events")).order_by("events_count") - ) - self.assertEqual(len(ordered_tournaments), 2) - self.assertEqual(ordered_tournaments[1].id, tournament.id) - event_with_annotation = ( - await Event.all().annotate(tournament_test_id=Sum("tournament__id")).first() - ) - self.assertEqual( - event_with_annotation.tournament_test_id, - event_with_annotation.tournament_id, - ) - - with self.assertRaisesRegex(FieldError, "name__id not resolvable"): - await Event.all().annotate(tournament_test_id=Sum("name__id")).first() - - async def test_nested_aggregation_in_annotation(self): - tournament = await Tournament.create(name="0") - await Tournament.create(name="1") - event = await Event.create(name="2", tournament=tournament) - - team_first = await Team.create(name="First") - team_second = await Team.create(name="Second") - - await event.participants.add(team_second) - await event.participants.add(team_first) - - tournaments = await Tournament.annotate( - events_participants_count=Count("events__participants") - ).filter(id=tournament.id) - self.assertEqual(tournaments[0].events_participants_count, 2) - - async def test_aggregation_with_distinct(self): - tournament = await Tournament.create(name="New Tournament") - await Event.create(name="Event 1", tournament=tournament) - await Event.create(name="Event 2", tournament=tournament) - await MinRelation.create(tournament=tournament) - - tournament_2 = await Tournament.create(name="New Tournament") - await Event.create(name="Event 1", tournament=tournament_2) - await Event.create(name="Event 2", tournament=tournament_2) - await Event.create(name="Event 3", tournament=tournament_2) - await MinRelation.create(tournament=tournament_2) - await MinRelation.create(tournament=tournament_2) - - school_with_distinct_count = ( - await Tournament.filter(id=tournament_2.id) - .annotate( - events_count=Count("events", distinct=True), - minrelations_count=Count("minrelations", distinct=True), - ) - 
.first() - ) - - self.assertEqual(school_with_distinct_count.events_count, 3) - self.assertEqual(school_with_distinct_count.minrelations_count, 2) - - async def test_aggregation_with_filter(self): - tournament = await Tournament.create(name="New Tournament") - await Event.create(name="Event 1", tournament=tournament) - await Event.create(name="Event 2", tournament=tournament) - await Event.create(name="Event 3", tournament=tournament) - - tournament_with_filter = ( - await Tournament.all() - .annotate( - all=Count("events", _filter=Q(name="New Tournament")), - one=Count("events", _filter=Q(events__name="Event 1")), - two=Count("events", _filter=Q(events__name__not="Event 1")), - ) - .first() - ) - - self.assertEqual(tournament_with_filter.all, 3) - self.assertEqual(tournament_with_filter.one, 1) - self.assertEqual(tournament_with_filter.two, 2) - - async def test_group_aggregation(self): - author = await Author.create(name="Some One") - await Book.create(name="First!", author=author, rating=4) - await Book.create(name="Second!", author=author, rating=3) - await Book.create(name="Third!", author=author, rating=3) - - authors = await Author.all().annotate(average_rating=Avg("books__rating")) - self.assertAlmostEqual(authors[0].average_rating, 3.3333333333) - - authors = await Author.all().annotate(average_rating=Avg("books__rating")).values() - self.assertAlmostEqual(authors[0]["average_rating"], 3.3333333333) - - authors = ( - await Author.all() - .annotate(average_rating=Avg("books__rating")) - .values("id", "name", "average_rating") - ) - self.assertAlmostEqual(authors[0]["average_rating"], 3.3333333333) - - authors = await Author.all().annotate(average_rating=Avg("books__rating")).values_list() - self.assertAlmostEqual(authors[0][2], 3.3333333333) - - authors = ( - await Author.all() - .annotate(average_rating=Avg("books__rating")) - .values_list("id", "name", "average_rating") - ) - self.assertAlmostEqual(authors[0][2], 3.3333333333) - - async def 
test_nested_functions(self): - author = await Author.create(name="Some One") - await Book.create(name="First!", author=author, rating=4) - await Book.create(name="Second!", author=author, rating=3) - await Book.create(name="Third!", author=author, rating=3) - ret = await Book.all().annotate(max_name=Lower(Max("name"))).values("max_name") - self.assertEqual(ret, [{"max_name": "third!"}]) - - @test.requireCapability(dialect=In("postgres", "mssql")) - async def test_concat_functions(self): - author = await Author.create(name="Some One") - await Book.create(name="Physics Book", author=author, rating=4, subject="physics ") - await Book.create(name="Mathematics Book", author=author, rating=3, subject=" mathematics") - await Book.create(name="No-subject Book", author=author, rating=3) - ret = ( - await Book.all() - .annotate(long_info=Max(Concat("name", "(", Coalesce(Trim("subject"), "others"), ")"))) - .values("long_info") - ) - self.assertEqual(ret, [{"long_info": "Physics Book(physics)"}]) - - async def test_count_after_aggregate(self): - author = await Author.create(name="1") - await Book.create(name="First!", author=author, rating=4) - await Book.create(name="Second!", author=author, rating=3) - await Book.create(name="Third!", author=author, rating=3) - - author2 = await Author.create(name="2") - await Book.create(name="F-2", author=author2, rating=3) - await Book.create(name="F-3", author=author2, rating=3) - - author3 = await Author.create(name="3") - await Book.create(name="F-4", author=author3, rating=3) - await Book.create(name="F-5", author=author3, rating=2) - ret = ( - await Author.all() - .annotate(average_rating=Avg("books__rating")) - .filter(average_rating__gte=3) - .count() - ) - - assert ret == 2 - - async def test_exist_after_aggregate(self): - author = await Author.create(name="1") - await Book.create(name="First!", author=author, rating=4) - await Book.create(name="Second!", author=author, rating=3) - await Book.create(name="Third!", author=author, 
rating=3) - - ret = ( - await Author.all() - .annotate(average_rating=Avg("books__rating")) - .filter(average_rating__gte=3) - .exists() - ) - - assert ret is True - - ret = ( - await Author.all() - .annotate(average_rating=Avg("books__rating")) - .filter(average_rating__gte=4) - .exists() +@pytest.mark.asyncio +async def test_aggregation(db): + tournament = Tournament(name="New Tournament") + await tournament.save() + await Tournament.create(name="Second tournament") + await Event(name="Without participants", tournament_id=tournament.id).save() + event = Event(name="Test", tournament_id=tournament.id) + await event.save() + participants = [] + for i in range(2): + team = Team(name=f"Team {(i + 1)}") + await team.save() + participants.append(team) + await event.participants.add(participants[0], participants[1]) + await event.participants.add(participants[0], participants[1]) + + tournaments_with_count = ( + await Tournament.all().annotate(events_count=Count("events")).filter(events_count__gte=1) + ) + assert len(tournaments_with_count) == 1 + assert tournaments_with_count[0].events_count == 2 + + event_with_lowest_team_id = ( + await Event.filter(event_id=event.event_id) + .first() + .annotate(lowest_team_id=Min("participants__id")) + ) + assert event_with_lowest_team_id.lowest_team_id == participants[0].id + + ordered_tournaments = ( + await Tournament.all().annotate(events_count=Count("events")).order_by("events_count") + ) + assert len(ordered_tournaments) == 2 + assert ordered_tournaments[1].id == tournament.id + event_with_annotation = ( + await Event.all().annotate(tournament_test_id=Sum("tournament__id")).first() + ) + assert event_with_annotation.tournament_test_id == event_with_annotation.tournament_id + + with pytest.raises(FieldError, match="name__id not resolvable"): + await Event.all().annotate(tournament_test_id=Sum("name__id")).first() + + +@pytest.mark.asyncio +async def test_nested_aggregation_in_annotation(db): + tournament = await 
Tournament.create(name="0") + await Tournament.create(name="1") + event = await Event.create(name="2", tournament=tournament) + + team_first = await Team.create(name="First") + team_second = await Team.create(name="Second") + + await event.participants.add(team_second) + await event.participants.add(team_first) + + tournaments = await Tournament.annotate( + events_participants_count=Count("events__participants") + ).filter(id=tournament.id) + assert tournaments[0].events_participants_count == 2 + + +@pytest.mark.asyncio +async def test_aggregation_with_distinct(db): + tournament = await Tournament.create(name="New Tournament") + await Event.create(name="Event 1", tournament=tournament) + await Event.create(name="Event 2", tournament=tournament) + await MinRelation.create(tournament=tournament) + + tournament_2 = await Tournament.create(name="New Tournament") + await Event.create(name="Event 1", tournament=tournament_2) + await Event.create(name="Event 2", tournament=tournament_2) + await Event.create(name="Event 3", tournament=tournament_2) + await MinRelation.create(tournament=tournament_2) + await MinRelation.create(tournament=tournament_2) + + school_with_distinct_count = ( + await Tournament.filter(id=tournament_2.id) + .annotate( + events_count=Count("events", distinct=True), + minrelations_count=Count("minrelations", distinct=True), ) - assert ret is False - - async def test_count_after_aggregate_m2m(self): - tournament = await Tournament.create(name="1") - event1 = await Event.create(name="First!", tournament=tournament) - event2 = await Event.create(name="Second!", tournament=tournament) - event3 = await Event.create(name="Third!", tournament=tournament) - event4 = await Event.create(name="Fourth!", tournament=tournament) - - team1 = await Team.create(name="1") - team2 = await Team.create(name="2") - team3 = await Team.create(name="3") - - await event1.participants.add(team1, team2, team3) - await event2.participants.add(team1, team2) - await 
event3.participants.add(team1) - await event4.participants.add(team1, team2, team3) - - query = ( - Event.filter(participants__id__in=[team1.id, team2.id, team3.id]) - .annotate(count=Count("event_id")) - .filter(count=3) - .prefetch_related("participants") + .first() + ) + + assert school_with_distinct_count.events_count == 3 + assert school_with_distinct_count.minrelations_count == 2 + + +@pytest.mark.asyncio +async def test_aggregation_with_filter(db): + tournament = await Tournament.create(name="New Tournament") + await Event.create(name="Event 1", tournament=tournament) + await Event.create(name="Event 2", tournament=tournament) + await Event.create(name="Event 3", tournament=tournament) + + tournament_with_filter = ( + await Tournament.all() + .annotate( + all=Count("events", _filter=Q(name="New Tournament")), + one=Count("events", _filter=Q(events__name="Event 1")), + two=Count("events", _filter=Q(events__name__not="Event 1")), ) - result = await query - assert len(result) == 2 + .first() + ) + + assert tournament_with_filter.all == 3 + assert tournament_with_filter.one == 1 + assert tournament_with_filter.two == 2 + + +@pytest.mark.asyncio +async def test_group_aggregation(db): + author = await Author.create(name="Some One") + await Book.create(name="First!", author=author, rating=4) + await Book.create(name="Second!", author=author, rating=3) + await Book.create(name="Third!", author=author, rating=3) + + authors = await Author.all().annotate(average_rating=Avg("books__rating")) + assert authors[0].average_rating == pytest.approx(3.3333333333, rel=1e-5) + + authors = await Author.all().annotate(average_rating=Avg("books__rating")).values() + assert authors[0]["average_rating"] == pytest.approx(3.3333333333, rel=1e-5) + + authors = ( + await Author.all() + .annotate(average_rating=Avg("books__rating")) + .values("id", "name", "average_rating") + ) + assert authors[0]["average_rating"] == pytest.approx(3.3333333333, rel=1e-5) + + authors = await 
Author.all().annotate(average_rating=Avg("books__rating")).values_list() + assert authors[0][2] == pytest.approx(3.3333333333, rel=1e-5) + + authors = ( + await Author.all() + .annotate(average_rating=Avg("books__rating")) + .values_list("id", "name", "average_rating") + ) + assert authors[0][2] == pytest.approx(3.3333333333, rel=1e-5) + + +@pytest.mark.asyncio +async def test_nested_functions(db): + author = await Author.create(name="Some One") + await Book.create(name="First!", author=author, rating=4) + await Book.create(name="Second!", author=author, rating=3) + await Book.create(name="Third!", author=author, rating=3) + ret = await Book.all().annotate(max_name=Lower(Max("name"))).values("max_name") + assert ret == [{"max_name": "third!"}] + + +@test.requireCapability(dialect=In("postgres", "mssql")) +@pytest.mark.asyncio +async def test_concat_functions(db): + author = await Author.create(name="Some One") + await Book.create(name="Physics Book", author=author, rating=4, subject="physics ") + await Book.create(name="Mathematics Book", author=author, rating=3, subject=" mathematics") + await Book.create(name="No-subject Book", author=author, rating=3) + ret = ( + await Book.all() + .annotate(long_info=Max(Concat("name", "(", Coalesce(Trim("subject"), "others"), ")"))) + .values("long_info") + ) + assert ret == [{"long_info": "Physics Book(physics)"}] + + +@pytest.mark.asyncio +async def test_count_after_aggregate(db): + author = await Author.create(name="1") + await Book.create(name="First!", author=author, rating=4) + await Book.create(name="Second!", author=author, rating=3) + await Book.create(name="Third!", author=author, rating=3) + + author2 = await Author.create(name="2") + await Book.create(name="F-2", author=author2, rating=3) + await Book.create(name="F-3", author=author2, rating=3) + + author3 = await Author.create(name="3") + await Book.create(name="F-4", author=author3, rating=3) + await Book.create(name="F-5", author=author3, rating=2) + ret = ( + 
await Author.all() + .annotate(average_rating=Avg("books__rating")) + .filter(average_rating__gte=3) + .count() + ) + + assert ret == 2 + + +@pytest.mark.asyncio +async def test_exist_after_aggregate(db): + author = await Author.create(name="1") + await Book.create(name="First!", author=author, rating=4) + await Book.create(name="Second!", author=author, rating=3) + await Book.create(name="Third!", author=author, rating=3) + + ret = ( + await Author.all() + .annotate(average_rating=Avg("books__rating")) + .filter(average_rating__gte=3) + .exists() + ) + + assert ret is True + + ret = ( + await Author.all() + .annotate(average_rating=Avg("books__rating")) + .filter(average_rating__gte=4) + .exists() + ) + assert ret is False + + +@pytest.mark.asyncio +async def test_count_after_aggregate_m2m(db): + tournament = await Tournament.create(name="1") + event1 = await Event.create(name="First!", tournament=tournament) + event2 = await Event.create(name="Second!", tournament=tournament) + event3 = await Event.create(name="Third!", tournament=tournament) + event4 = await Event.create(name="Fourth!", tournament=tournament) + + team1 = await Team.create(name="1") + team2 = await Team.create(name="2") + team3 = await Team.create(name="3") + + await event1.participants.add(team1, team2, team3) + await event2.participants.add(team1, team2) + await event3.participants.add(team1) + await event4.participants.add(team1, team2, team3) + + query = ( + Event.filter(participants__id__in=[team1.id, team2.id, team3.id]) + .annotate(count=Count("event_id")) + .filter(count=3) + .prefetch_related("participants") + ) + result = await query + assert len(result) == 2 + + res = await query.count() + assert res == 2 + + +@pytest.mark.asyncio +async def test_where_and_having(db): + author = await Author.create(name="1") + await Book.create(name="First!", author=author, rating=4) + await Book.create(name="Second!", author=author, rating=3) + await Book.create(name="Third!", author=author, rating=3) 
+ + query = Book.exclude(name="First!").annotate(avg_rating=Avg("rating")).values("avg_rating") + result = await query + assert len(result) == 1 + assert result[0]["avg_rating"] == 3 + + +@pytest.mark.asyncio +async def test_count_without_matching(db) -> None: + await Tournament.create(name="Test") - res = await query.count() - assert res == 2 + query = Tournament.annotate(events_count=Count("events")).filter(events_count__gt=0).count() + result = await query + assert result == 0 - async def test_where_and_having(self): - author = await Author.create(name="1") - await Book.create(name="First!", author=author, rating=4) - await Book.create(name="Second!", author=author, rating=3) - await Book.create(name="Third!", author=author, rating=3) - query = Book.exclude(name="First!").annotate(avg_rating=Avg("rating")).values("avg_rating") - result = await query - assert len(result) == 1 - assert result[0]["avg_rating"] == 3 +@pytest.mark.asyncio +async def test_int_sum_on_models_with_validators(db) -> None: + await ValidatorModel.create(max_value=2) + await ValidatorModel.create(max_value=2) + + query = ValidatorModel.annotate(sum=Sum("max_value")).values("sum") + result = await query + assert result == [{"sum": 4}] - async def test_count_without_matching(self) -> None: - await Tournament.create(name="Test") - query = Tournament.annotate(events_count=Count("events")).filter(events_count__gt=0).count() - result = await query - assert result == 0 +@pytest.mark.asyncio +async def test_int_sum_math_on_models_with_validators(db) -> None: + await ValidatorModel.create(max_value=4) + await ValidatorModel.create(max_value=4) - async def test_int_sum_on_models_with_validators(self) -> None: - await ValidatorModel.create(max_value=2) - await ValidatorModel.create(max_value=2) + query = ValidatorModel.annotate(sum=Sum(F("max_value") * F("max_value"))).values("sum") + result = await query + assert result == [{"sum": 32}] - query = 
ValidatorModel.annotate(sum=Sum("max_value")).values("sum") - result = await query - self.assertEqual(result, [{"sum": 4}]) - async def test_int_sum_math_on_models_with_validators(self) -> None: - await ValidatorModel.create(max_value=4) - await ValidatorModel.create(max_value=4) +@pytest.mark.asyncio +async def test_decimal_sum_on_models_with_validators(db) -> None: + await ValidatorModel.create(min_value_decimal=2.0) - query = ValidatorModel.annotate(sum=Sum(F("max_value") * F("max_value"))).values("sum") - result = await query - self.assertEqual(result, [{"sum": 32}]) + query = ValidatorModel.annotate(sum=Sum("min_value_decimal")).values("sum") + result = await query + assert result == [{"sum": Decimal("2.0")}] - async def test_decimal_sum_on_models_with_validators(self) -> None: - await ValidatorModel.create(min_value_decimal=2.0) - query = ValidatorModel.annotate(sum=Sum("min_value_decimal")).values("sum") - result = await query - self.assertEqual(result, [{"sum": Decimal("2.0")}]) +@pytest.mark.asyncio +async def test_decimal_sum_with_math_on_models_with_validators(db) -> None: + await ValidatorModel.create(min_value_decimal=2.0) - async def test_decimal_sum_with_math_on_models_with_validators(self) -> None: - await ValidatorModel.create(min_value_decimal=2.0) + query = ValidatorModel.annotate( + sum=Sum(F("min_value_decimal") - F("min_value_decimal") * F("min_value_decimal")) + ).values("sum") + result = await query + assert result == [{"sum": Decimal("-2.0")}] - query = ValidatorModel.annotate( - sum=Sum(F("min_value_decimal") - F("min_value_decimal") * F("min_value_decimal")) - ).values("sum") - result = await query - self.assertEqual(result, [{"sum": Decimal("-2.0")}]) - async def test_function_requiring_nested_joins(self): - tournament = await Tournament.create(name="Tournament") +@pytest.mark.asyncio +async def test_function_requiring_nested_joins(db): + tournament = await Tournament.create(name="Tournament") - event_first = await Event.create(name="1", 
tournament=tournament) - event_second = await Event.create(name="2", tournament=tournament) + event_first = await Event.create(name="1", tournament=tournament) + event_second = await Event.create(name="2", tournament=tournament) - team_first = await Team.create(name="First", alias=2) - team_second = await Team.create(name="Second", alias=10) + team_first = await Team.create(name="First", alias=2) + team_second = await Team.create(name="Second", alias=10) - await team_first.events.add(event_first) - await event_second.participants.add(team_second) + await team_first.events.add(event_first) + await event_second.participants.add(team_second) - res = await Tournament.annotate(avg=Avg("events__participants__alias")).values("avg") - self.assertEqual(res, [{"avg": 6}]) + res = await Tournament.annotate(avg=Avg("events__participants__alias")).values("avg") + assert res == [{"avg": 6}] diff --git a/tests/test_basic.py b/tests/test_basic.py index f09f7554f..de2400b59 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -1,30 +1,36 @@ +import pytest + from tests.testmodels import OldStyleModel, Tournament -from tortoise.contrib import test - - -class TestBasic(test.TestCase): - async def test_basic(self): - tournament = await Tournament.create(name="Test") - await Tournament.filter(id=tournament.id).update(name="Updated name") - saved_event = await Tournament.filter(name="Updated name").first() - self.assertEqual(saved_event.id, tournament.id) - await Tournament(name="Test 2").save() - self.assertEqual( - await Tournament.all().values_list("id", flat=True), - [tournament.id, tournament.id + 1], - ) - self.assertListSortEqual( - await Tournament.all().values("id", "name"), - [ - {"id": tournament.id, "name": "Updated name"}, - {"id": tournament.id + 1, "name": "Test 2"}, - ], - sorted_key="id", - ) - - async def test_basic_oldstyle(self): - obj = await OldStyleModel.create(external_id=123) - assert obj.pk - - assert OldStyleModel._meta.fields_map["id"].pk - assert 
OldStyleModel._meta.fields_map["external_id"].index + + +@pytest.mark.asyncio +async def test_basic(db): + """Test basic CRUD operations with Tournament model.""" + tournament = await Tournament.create(name="Test") + await Tournament.filter(id=tournament.id).update(name="Updated name") + saved_event = await Tournament.filter(name="Updated name").first() + assert saved_event.id == tournament.id + + await Tournament(name="Test 2").save() + assert await Tournament.all().values_list("id", flat=True) == [ + tournament.id, + tournament.id + 1, + ] + + # Compare sorted by id to ensure consistent ordering + result = await Tournament.all().values("id", "name") + expected = [ + {"id": tournament.id, "name": "Updated name"}, + {"id": tournament.id + 1, "name": "Test 2"}, + ] + assert sorted(result, key=lambda x: x["id"]) == sorted(expected, key=lambda x: x["id"]) + + +@pytest.mark.asyncio +async def test_basic_oldstyle(db): + """Test OldStyleModel with external_id field.""" + obj = await OldStyleModel.create(external_id=123) + assert obj.pk + + assert OldStyleModel._meta.fields_map["id"].pk + assert OldStyleModel._meta.fields_map["external_id"].index diff --git a/tests/test_bulk.py b/tests/test_bulk.py index d5392e96c..52a629f42 100644 --- a/tests/test_bulk.py +++ b/tests/test_bulk.py @@ -1,165 +1,209 @@ from uuid import UUID, uuid4 +import pytest + from tests.testmodels import UniqueName, UUIDPkModel -from tortoise.contrib import test +from tortoise.contrib.test import requireCapability from tortoise.contrib.test.condition import NotEQ from tortoise.exceptions import IntegrityError from tortoise.transactions import in_transaction -class TestBulk(test.TruncationTestCase): - async def test_bulk_create(self): - await UniqueName.bulk_create([UniqueName() for _ in range(1000)]) - all_ = await UniqueName.all().values("id", "name") +def assert_list_sort_equal(actual, expected, sorted_key="id"): + """Assert two lists are equal after sorting by the given key.""" + assert 
sorted(actual, key=lambda x: x[sorted_key]) == sorted( + expected, key=lambda x: x[sorted_key] + ) + + +@pytest.mark.asyncio +async def test_bulk_create(db_truncate): + """Test basic bulk create operation.""" + await UniqueName.bulk_create([UniqueName() for _ in range(1000)]) + all_ = await UniqueName.all().values("id", "name") + inc = all_[0]["id"] + assert_list_sort_equal( + all_, + [{"id": val + inc, "name": None} for val in range(1000)], + sorted_key="id", + ) + + +@requireCapability(dialect=NotEQ("mssql")) +@pytest.mark.asyncio +async def test_bulk_create_update_fields(db_truncate): + """Test bulk create with update_fields on conflict.""" + await UniqueName.bulk_create([UniqueName(name="name")]) + await UniqueName.bulk_create( + [UniqueName(name="name", optional="optional")], + update_fields=["optional"], + on_conflict=["name"], + ) + all_ = await UniqueName.all().values("name", "optional") + assert all_ == [{"name": "name", "optional": "optional"}] + + +@requireCapability(dialect=NotEQ("mssql")) +@pytest.mark.asyncio +async def test_bulk_create_more_that_one_update_fields(db_truncate): + """Test bulk create with multiple update_fields on conflict.""" + await UniqueName.bulk_create([UniqueName(name="name")]) + await UniqueName.bulk_create( + [UniqueName(name="name", optional="optional", other_optional="other_optional")], + update_fields=["optional", "other_optional"], + on_conflict=["name"], + ) + all_ = await UniqueName.all().values("name", "optional", "other_optional") + assert all_ == [ + { + "name": "name", + "optional": "optional", + "other_optional": "other_optional", + } + ] + + +@requireCapability(dialect=NotEQ("mssql")) +@pytest.mark.asyncio +async def test_bulk_create_with_batch_size(db_truncate): + """Test bulk create with batch_size parameter.""" + await UniqueName.bulk_create([UniqueName(id=id_ + 1) for id_ in range(1000)], batch_size=100) + all_ = await UniqueName.all().values("id", "name") + assert_list_sort_equal( + all_, + [{"id": val + 1, 
"name": None} for val in range(1000)], + sorted_key="id", + ) + + +@requireCapability(dialect=NotEQ("mssql")) +@pytest.mark.asyncio +async def test_bulk_create_with_specified(db_truncate): + """Test bulk create with specified IDs.""" + await UniqueName.bulk_create([UniqueName(id=id_) for id_ in range(1000, 2000)]) + all_ = await UniqueName.all().values("id", "name") + assert_list_sort_equal( + all_, + [{"id": id_, "name": None} for id_ in range(1000, 2000)], + sorted_key="id", + ) + + +@requireCapability(dialect=NotEQ("mssql")) +@pytest.mark.asyncio +async def test_bulk_create_mix_specified(db_truncate): + """Test bulk create with mix of specified and auto-generated IDs.""" + predefined_start = 40000 + predefined_end = 40150 + undefined_count = 100 + + await UniqueName.bulk_create( + [UniqueName(id=id_) for id_ in range(predefined_start, predefined_end)] + + [UniqueName() for _ in range(undefined_count)] + ) + + all_ = await UniqueName.all().order_by("id").values("id", "name") + predefined_count = predefined_end - predefined_start + assert len(all_) == (predefined_count + undefined_count) + + if all_[0]["id"] == predefined_start: + assert sorted(all_[:predefined_count], key=lambda x: x["id"]) == [ + {"id": id_, "name": None} for id_ in range(predefined_start, predefined_end) + ] + inc = all_[predefined_count]["id"] + assert sorted(all_[predefined_count:], key=lambda x: x["id"]) == [ + {"id": val + inc, "name": None} for val in range(undefined_count) + ] + else: inc = all_[0]["id"] - self.assertListSortEqual( - all_, - [{"id": val + inc, "name": None} for val in range(1000)], - sorted_key="id", - ) - - @test.requireCapability(dialect=NotEQ("mssql")) - async def test_bulk_create_update_fields(self): - await UniqueName.bulk_create([UniqueName(name="name")]) - await UniqueName.bulk_create( - [UniqueName(name="name", optional="optional")], - update_fields=["optional"], - on_conflict=["name"], - ) - all_ = await UniqueName.all().values("name", "optional") - 
self.assertListSortEqual(all_, [{"name": "name", "optional": "optional"}]) + assert sorted(all_[:undefined_count], key=lambda x: x["id"]) == [ + {"id": val + inc, "name": None} for val in range(undefined_count) + ] + assert sorted(all_[undefined_count:], key=lambda x: x["id"]) == [ + {"id": id_, "name": None} for id_ in range(predefined_start, predefined_end) + ] + + +@pytest.mark.asyncio +async def test_bulk_create_uuidpk(db_truncate): + """Test bulk create with UUID primary key model.""" + await UUIDPkModel.bulk_create([UUIDPkModel() for _ in range(1000)]) + res = await UUIDPkModel.all().values_list("id", flat=True) + assert len(res) == 1000 + assert isinstance(res[0], UUID) + + +@requireCapability(supports_transactions=True) +@requireCapability(dialect=NotEQ("mssql")) +@pytest.mark.asyncio +async def test_bulk_create_in_transaction(db_truncate): + """Test bulk create inside transaction.""" + async with in_transaction(): + await UniqueName.bulk_create([UniqueName() for _ in range(1000)]) + all_ = await UniqueName.all().order_by("id").values("id", "name") + inc = all_[0]["id"] + assert all_ == [{"id": val + inc, "name": None} for val in range(1000)] - @test.requireCapability(dialect=NotEQ("mssql")) - async def test_bulk_create_more_that_one_update_fields(self): - await UniqueName.bulk_create([UniqueName(name="name")]) - await UniqueName.bulk_create( - [UniqueName(name="name", optional="optional", other_optional="other_optional")], - update_fields=["optional", "other_optional"], - on_conflict=["name"], - ) - all_ = await UniqueName.all().values("name", "optional", "other_optional") - self.assertListSortEqual( - all_, - [ - { - "name": "name", - "optional": "optional", - "other_optional": "other_optional", - } - ], - ) - @test.requireCapability(dialect=NotEQ("mssql")) - async def test_bulk_create_with_batch_size(self): - await UniqueName.bulk_create( - [UniqueName(id=id_ + 1) for id_ in range(1000)], batch_size=100 - ) - all_ = await UniqueName.all().values("id", 
"name") - self.assertListSortEqual( - all_, - [{"id": val + 1, "name": None} for val in range(1000)], - sorted_key="id", - ) - - @test.requireCapability(dialect=NotEQ("mssql")) - async def test_bulk_create_with_specified(self): - await UniqueName.bulk_create([UniqueName(id=id_) for id_ in range(1000, 2000)]) - all_ = await UniqueName.all().values("id", "name") - self.assertListSortEqual( - all_, - [{"id": id_, "name": None} for id_ in range(1000, 2000)], - sorted_key="id", - ) +@requireCapability(supports_transactions=True) +@pytest.mark.asyncio +async def test_bulk_create_uuidpk_in_transaction(db_truncate): + """Test bulk create with UUID PK inside transaction.""" + async with in_transaction(): + await UUIDPkModel.bulk_create([UUIDPkModel() for _ in range(1000)]) + res = await UUIDPkModel.all().values_list("id", flat=True) + assert len(res) == 1000 + assert isinstance(res[0], UUID) - @test.requireCapability(dialect=NotEQ("mssql")) - async def test_bulk_create_mix_specified(self): - predefined_start = 40000 - predefined_end = 40150 - undefined_count = 100 +@requireCapability(dialect=NotEQ("mssql")) +@pytest.mark.asyncio +async def test_bulk_create_fail(db_truncate): + """Test bulk create fails with duplicate names.""" + with pytest.raises(IntegrityError): await UniqueName.bulk_create( - [UniqueName(id=id_) for id_ in range(predefined_start, predefined_end)] - + [UniqueName() for _ in range(undefined_count)] + [UniqueName(name=str(i)) for i in range(10)] + + [UniqueName(name=str(i)) for i in range(10)] ) - all_ = await UniqueName.all().order_by("id").values("id", "name") - predefined_count = predefined_end - predefined_start - assert len(all_) == (predefined_count + undefined_count) - - if all_[0]["id"] == predefined_start: - assert sorted(all_[:predefined_count], key=lambda x: x["id"]) == [ - {"id": id_, "name": None} for id_ in range(predefined_start, predefined_end) - ] - inc = all_[predefined_count]["id"] - assert sorted(all_[predefined_count:], key=lambda x: 
x["id"]) == [ - {"id": val + inc, "name": None} for val in range(undefined_count) - ] - else: - inc = all_[0]["id"] - assert sorted(all_[:undefined_count], key=lambda x: x["id"]) == [ - {"id": val + inc, "name": None} for val in range(undefined_count) - ] - assert sorted(all_[undefined_count:], key=lambda x: x["id"]) == [ - {"id": id_, "name": None} for id_ in range(predefined_start, predefined_end) - ] - - async def test_bulk_create_uuidpk(self): - await UUIDPkModel.bulk_create([UUIDPkModel() for _ in range(1000)]) - res = await UUIDPkModel.all().values_list("id", flat=True) - self.assertEqual(len(res), 1000) - self.assertIsInstance(res[0], UUID) - @test.requireCapability(supports_transactions=True) - @test.requireCapability(dialect=NotEQ("mssql")) - async def test_bulk_create_in_transaction(self): - async with in_transaction(): - await UniqueName.bulk_create([UniqueName() for _ in range(1000)]) - all_ = await UniqueName.all().order_by("id").values("id", "name") - inc = all_[0]["id"] - self.assertEqual(all_, [{"id": val + inc, "name": None} for val in range(1000)]) +@pytest.mark.asyncio +async def test_bulk_create_uuidpk_fail(db_truncate): + """Test bulk create fails with duplicate UUID PKs.""" + val = uuid4() + with pytest.raises(IntegrityError): + await UUIDPkModel.bulk_create([UUIDPkModel(id=val) for _ in range(10)]) + - @test.requireCapability(supports_transactions=True) - async def test_bulk_create_uuidpk_in_transaction(self): +@requireCapability(supports_transactions=True, dialect=NotEQ("mssql")) +@pytest.mark.asyncio +async def test_bulk_create_in_transaction_fail(db_truncate): + """Test bulk create fails inside transaction with duplicates.""" + with pytest.raises(IntegrityError): async with in_transaction(): - await UUIDPkModel.bulk_create([UUIDPkModel() for _ in range(1000)]) - res = await UUIDPkModel.all().values_list("id", flat=True) - self.assertEqual(len(res), 1000) - self.assertIsInstance(res[0], UUID) - - 
@test.requireCapability(dialect=NotEQ("mssql")) - async def test_bulk_create_fail(self): - with self.assertRaises(IntegrityError): await UniqueName.bulk_create( [UniqueName(name=str(i)) for i in range(10)] + [UniqueName(name=str(i)) for i in range(10)] ) - async def test_bulk_create_uuidpk_fail(self): - val = uuid4() - with self.assertRaises(IntegrityError): + +@requireCapability(supports_transactions=True) +@pytest.mark.asyncio +async def test_bulk_create_uuidpk_in_transaction_fail(db_truncate): + """Test bulk create with UUID PK fails in transaction with duplicates.""" + val = uuid4() + with pytest.raises(IntegrityError): + async with in_transaction(): await UUIDPkModel.bulk_create([UUIDPkModel(id=val) for _ in range(10)]) - @test.requireCapability(supports_transactions=True, dialect=NotEQ("mssql")) - async def test_bulk_create_in_transaction_fail(self): - with self.assertRaises(IntegrityError): - async with in_transaction(): - await UniqueName.bulk_create( - [UniqueName(name=str(i)) for i in range(10)] - + [UniqueName(name=str(i)) for i in range(10)] - ) - - @test.requireCapability(supports_transactions=True) - async def test_bulk_create_uuidpk_in_transaction_fail(self): - val = uuid4() - with self.assertRaises(IntegrityError): - async with in_transaction(): - await UUIDPkModel.bulk_create([UUIDPkModel(id=val) for _ in range(10)]) - - @test.requireCapability(dialect=NotEQ("mssql")) - async def test_bulk_create_ignore_conflicts(self): - name1 = UniqueName(name="name1") - name2 = UniqueName(name="name2") + +@requireCapability(dialect=NotEQ("mssql")) +@pytest.mark.asyncio +async def test_bulk_create_ignore_conflicts(db_truncate): + """Test bulk create with ignore_conflicts option.""" + name1 = UniqueName(name="name1") + name2 = UniqueName(name="name2") + await UniqueName.bulk_create([name1, name2]) + await UniqueName.bulk_create([name1, name2], ignore_conflicts=True) + with pytest.raises(IntegrityError): await UniqueName.bulk_create([name1, name2]) - await 
UniqueName.bulk_create([name1, name2], ignore_conflicts=True) - with self.assertRaises(IntegrityError): - await UniqueName.bulk_create([name1, name2]) diff --git a/tests/test_callable_default.py b/tests/test_callable_default.py index 1d0075065..360681052 100644 --- a/tests/test_callable_default.py +++ b/tests/test_callable_default.py @@ -1,21 +1,26 @@ +import pytest + from tests import testmodels -from tortoise.contrib import test -class TestCallableDefault(test.TestCase): - async def test_default_create(self): - model = await testmodels.CallableDefault.create() - self.assertEqual(model.callable_default, "callable_default") - self.assertEqual(model.async_default, "async_callable_default") +@pytest.mark.asyncio +async def test_default_create(db): + model = await testmodels.CallableDefault.create() + assert model.callable_default == "callable_default" + assert model.async_default == "async_callable_default" + + +@pytest.mark.asyncio +async def test_default_by_save(db): + saved_model = testmodels.CallableDefault() + await saved_model.save() + assert saved_model.callable_default == "callable_default" + assert saved_model.async_default == "async_callable_default" - async def test_default_by_save(self): - saved_model = testmodels.CallableDefault() - await saved_model.save() - self.assertEqual(saved_model.callable_default, "callable_default") - self.assertEqual(saved_model.async_default, "async_callable_default") - async def test_async_default_change(self): - default_change = testmodels.CallableDefault() - default_change.async_default = "changed" - await default_change.save() - self.assertEqual(default_change.async_default, "changed") +@pytest.mark.asyncio +async def test_async_default_change(db): + default_change = testmodels.CallableDefault() + default_change.async_default = "changed" + await default_change.save() + assert default_change.async_default == "changed" diff --git a/tests/test_case_when.py b/tests/test_case_when.py index f696b9930..074a5169c 100644 --- 
a/tests/test_case_when.py +++ b/tests/test_case_when.py @@ -1,236 +1,265 @@ +import pytest +import pytest_asyncio + from tests.testmodels import IntFields -from tortoise import connections -from tortoise.contrib import test from tortoise.exceptions import FieldError from tortoise.expressions import Case, F, Q, When from tortoise.functions import Coalesce, Count -class TestCaseWhen(test.TestCase): - async def asyncSetUp(self): - await super().asyncSetUp() - self.intfields = [await IntFields.create(intnum=val) for val in range(10)] - self.db = connections.get("models") - - async def test_single_when(self): - category = Case(When(intnum__gte=8, then="big"), default="default") - sql = ( - IntFields.all() - .annotate(category=category) - .values("intnum", "category") - .sql(params_inline=True) - ) - - dialect = self.db.schema_generator.DIALECT - if dialect == "mysql": - expected_sql = "SELECT `intnum` `intnum`,CASE WHEN `intnum`>=8 THEN 'big' ELSE 'default' END `category` FROM `intfields`" - else: - expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=8 THEN \'big\' ELSE \'default\' END "category" FROM "intfields"' - self.assertEqual(sql, expected_sql) - - async def test_multi_when(self): - category = Case( - When(intnum__gte=8, then="big"), When(intnum__lte=2, then="small"), default="default" - ) - sql = ( - IntFields.all() - .annotate(category=category) - .values("intnum", "category") - .sql(params_inline=True) - ) - - dialect = self.db.schema_generator.DIALECT - if dialect == "mysql": - expected_sql = "SELECT `intnum` `intnum`,CASE WHEN `intnum`>=8 THEN 'big' WHEN `intnum`<=2 THEN 'small' ELSE 'default' END `category` FROM `intfields`" - else: - expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=8 THEN \'big\' WHEN "intnum"<=2 THEN \'small\' ELSE \'default\' END "category" FROM "intfields"' - self.assertEqual(sql, expected_sql) - - async def test_q_object_when(self): - category = Case(When(Q(intnum__gt=2, intnum__lt=8), then="middle"), 
default="default") - sql = ( - IntFields.all() - .annotate(category=category) - .values("intnum", "category") - .sql(params_inline=True) - ) - - dialect = self.db.schema_generator.DIALECT - if dialect == "mysql": - expected_sql = "SELECT `intnum` `intnum`,CASE WHEN `intnum`>2 AND `intnum`<8 THEN 'middle' ELSE 'default' END `category` FROM `intfields`" - else: - expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">2 AND "intnum"<8 THEN \'middle\' ELSE \'default\' END "category" FROM "intfields"' - self.assertEqual(sql, expected_sql) - - async def test_F_then(self): - category = Case(When(intnum__gte=8, then=F("intnum_null")), default="default") - sql = ( - IntFields.all() - .annotate(category=category) - .values("intnum", "category") - .sql(params_inline=True) - ) - - dialect = self.db.schema_generator.DIALECT - if dialect == "mysql": - expected_sql = "SELECT `intnum` `intnum`,CASE WHEN `intnum`>=8 THEN `intnum_null` ELSE 'default' END `category` FROM `intfields`" - else: - expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=8 THEN "intnum_null" ELSE \'default\' END "category" FROM "intfields"' - self.assertEqual(sql, expected_sql) - - async def test_AE_then(self): - # AE: ArithmeticExpression - category = Case(When(intnum__gte=8, then=F("intnum") + 1), default="default") - sql = ( - IntFields.all() - .annotate(category=category) - .values("intnum", "category") - .sql(params_inline=True) - ) - - dialect = self.db.schema_generator.DIALECT - if dialect == "mysql": - expected_sql = "SELECT `intnum` `intnum`,CASE WHEN `intnum`>=8 THEN `intnum`+1 ELSE 'default' END `category` FROM `intfields`" - else: - expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=8 THEN "intnum"+1 ELSE \'default\' END "category" FROM "intfields"' - self.assertEqual(sql, expected_sql) - - async def test_func_then(self): - category = Case(When(intnum__gte=8, then=Coalesce("intnum_null", 10)), default="default") - sql = ( - IntFields.all() - .annotate(category=category) - 
.values("intnum", "category") - .sql(params_inline=True) - ) - - dialect = self.db.schema_generator.DIALECT - if dialect == "mysql": - expected_sql = "SELECT `intnum` `intnum`,CASE WHEN `intnum`>=8 THEN COALESCE(`intnum_null`,10) ELSE 'default' END `category` FROM `intfields`" - else: - expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=8 THEN COALESCE("intnum_null",10) ELSE \'default\' END "category" FROM "intfields"' - self.assertEqual(sql, expected_sql) - - async def test_F_default(self): - category = Case(When(intnum__gte=8, then="big"), default=F("intnum_null")) - sql = ( - IntFields.all() - .annotate(category=category) - .values("intnum", "category") - .sql(params_inline=True) - ) - - dialect = self.db.schema_generator.DIALECT - if dialect == "mysql": - expected_sql = "SELECT `intnum` `intnum`,CASE WHEN `intnum`>=8 THEN 'big' ELSE `intnum_null` END `category` FROM `intfields`" - else: - expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=8 THEN \'big\' ELSE "intnum_null" END "category" FROM "intfields"' - self.assertEqual(sql, expected_sql) - - async def test_AE_default(self): - # AE: ArithmeticExpression - category = Case(When(intnum__gte=8, then=8), default=F("intnum") + 1) - sql = ( - IntFields.all() - .annotate(category=category) - .values("intnum", "category") - .sql(params_inline=True) - ) - - dialect = self.db.schema_generator.DIALECT - if dialect == "mysql": - expected_sql = "SELECT `intnum` `intnum`,CASE WHEN `intnum`>=8 THEN 8 ELSE `intnum`+1 END `category` FROM `intfields`" - else: - expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=8 THEN 8 ELSE "intnum"+1 END "category" FROM "intfields"' - self.assertEqual(sql, expected_sql) - - async def test_func_default(self): - category = Case(When(intnum__gte=8, then=8), default=Coalesce("intnum_null", 10)) - sql = ( - IntFields.all() - .annotate(category=category) - .values("intnum", "category") - .sql(params_inline=True) - ) - - dialect = self.db.schema_generator.DIALECT - if 
dialect == "mysql": - expected_sql = "SELECT `intnum` `intnum`,CASE WHEN `intnum`>=8 THEN 8 ELSE COALESCE(`intnum_null`,10) END `category` FROM `intfields`" - else: - expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=8 THEN 8 ELSE COALESCE("intnum_null",10) END "category" FROM "intfields"' - self.assertEqual(sql, expected_sql) - - async def test_case_when_in_where(self): - category = Case( - When(intnum__gte=8, then="big"), When(intnum__lte=2, then="small"), default="middle" - ) - sql = ( - IntFields.all() - .annotate(category=category) - .filter(category__in=["big", "small"]) - .values("intnum") - .sql(params_inline=True) - ) - dialect = self.db.schema_generator.DIALECT - if dialect == "mysql": - expected_sql = "SELECT `intnum` `intnum` FROM `intfields` WHERE CASE WHEN `intnum`>=8 THEN 'big' WHEN `intnum`<=2 THEN 'small' ELSE 'middle' END IN ('big','small')" - else: - expected_sql = "SELECT \"intnum\" \"intnum\" FROM \"intfields\" WHERE CASE WHEN \"intnum\">=8 THEN 'big' WHEN \"intnum\"<=2 THEN 'small' ELSE 'middle' END IN ('big','small')" - self.assertEqual(sql, expected_sql) - - async def test_annotation_in_when_annotation(self): - sql = ( - IntFields.all() - .annotate(intnum_plus_1=F("intnum") + 1) - .annotate(bigger_than_10=Case(When(Q(intnum_plus_1__gte=10), then=True), default=False)) - .values("id", "intnum", "intnum_plus_1", "bigger_than_10") - .sql(params_inline=True) - ) - - dialect = self.db.schema_generator.DIALECT - if dialect == "mysql": - expected_sql = "SELECT `id` `id`,`intnum` `intnum`,`intnum`+1 `intnum_plus_1`,CASE WHEN `intnum`+1>=10 THEN true ELSE false END `bigger_than_10` FROM `intfields`" - else: - expected_sql = 'SELECT "id" "id","intnum" "intnum","intnum"+1 "intnum_plus_1",CASE WHEN "intnum"+1>=10 THEN true ELSE false END "bigger_than_10" FROM "intfields"' - self.assertEqual(sql, expected_sql) - - async def test_func_annotation_in_when_annotation(self): - sql = ( - IntFields.all() - .annotate(intnum_col=Coalesce("intnum", 0)) 
- .annotate(is_zero=Case(When(Q(intnum_col=0), then=True), default=False)) - .values("id", "intnum_col", "is_zero") - .sql(params_inline=True) - ) - - dialect = self.db.schema_generator.DIALECT - if dialect == "mysql": - expected_sql = "SELECT `id` `id`,COALESCE(`intnum`,0) `intnum_col`,CASE WHEN COALESCE(`intnum`,0)=0 THEN true ELSE false END `is_zero` FROM `intfields`" - else: - expected_sql = 'SELECT "id" "id",COALESCE("intnum",0) "intnum_col",CASE WHEN COALESCE("intnum",0)=0 THEN true ELSE false END "is_zero" FROM "intfields"' - self.assertEqual(sql, expected_sql) - - async def test_case_when_in_group_by(self): - sql = ( - IntFields.all() - .annotate(is_zero=Case(When(Q(intnum=0), then=True), default=False)) - .annotate(count=Count("id")) - .group_by("is_zero") - .values("is_zero", "count") - .sql(params_inline=True) - ) - - dialect = self.db.schema_generator.DIALECT - if dialect == "mysql": - expected_sql = "SELECT CASE WHEN `intnum`=0 THEN true ELSE false END `is_zero`,COUNT(`id`) `count` FROM `intfields` GROUP BY `is_zero`" - elif dialect == "mssql": - expected_sql = 'SELECT CASE WHEN "intnum"=0 THEN true ELSE false END "is_zero",COUNT("id") "count" FROM "intfields" GROUP BY CASE WHEN "intnum"=0 THEN true ELSE false END' - else: - expected_sql = 'SELECT CASE WHEN "intnum"=0 THEN true ELSE false END "is_zero",COUNT("id") "count" FROM "intfields" GROUP BY "is_zero"' - self.assertEqual(sql, expected_sql) - - async def test_unknown_field_in_when_annotation(self): - with self.assertRaisesRegex(FieldError, "Unknown filter param 'unknown'.+"): - IntFields.all().annotate(intnum_col=Coalesce("intnum", 0)).annotate( - is_zero=Case(When(Q(unknown=0), then="1"), default="2") - ).sql(params_inline=True) +@pytest_asyncio.fixture +async def intfields_data(db): + """Create IntFields test data.""" + intfields = [await IntFields.create(intnum=val) for val in range(10)] + return intfields + + +@pytest.mark.asyncio +async def test_single_when(db, intfields_data): + category = 
Case(When(intnum__gte=8, then="big"), default="default") + sql = ( + IntFields.all() + .annotate(category=category) + .values("intnum", "category") + .sql(params_inline=True) + ) + + dialect = db.db().schema_generator.DIALECT + if dialect == "mysql": + expected_sql = "SELECT `intnum` `intnum`,CASE WHEN `intnum`>=8 THEN 'big' ELSE 'default' END `category` FROM `intfields`" + else: + expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=8 THEN \'big\' ELSE \'default\' END "category" FROM "intfields"' + assert sql == expected_sql + + +@pytest.mark.asyncio +async def test_multi_when(db, intfields_data): + category = Case( + When(intnum__gte=8, then="big"), When(intnum__lte=2, then="small"), default="default" + ) + sql = ( + IntFields.all() + .annotate(category=category) + .values("intnum", "category") + .sql(params_inline=True) + ) + + dialect = db.db().schema_generator.DIALECT + if dialect == "mysql": + expected_sql = "SELECT `intnum` `intnum`,CASE WHEN `intnum`>=8 THEN 'big' WHEN `intnum`<=2 THEN 'small' ELSE 'default' END `category` FROM `intfields`" + else: + expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=8 THEN \'big\' WHEN "intnum"<=2 THEN \'small\' ELSE \'default\' END "category" FROM "intfields"' + assert sql == expected_sql + + +@pytest.mark.asyncio +async def test_q_object_when(db, intfields_data): + category = Case(When(Q(intnum__gt=2, intnum__lt=8), then="middle"), default="default") + sql = ( + IntFields.all() + .annotate(category=category) + .values("intnum", "category") + .sql(params_inline=True) + ) + + dialect = db.db().schema_generator.DIALECT + if dialect == "mysql": + expected_sql = "SELECT `intnum` `intnum`,CASE WHEN `intnum`>2 AND `intnum`<8 THEN 'middle' ELSE 'default' END `category` FROM `intfields`" + else: + expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">2 AND "intnum"<8 THEN \'middle\' ELSE \'default\' END "category" FROM "intfields"' + assert sql == expected_sql + + +@pytest.mark.asyncio +async def 
test_F_then(db, intfields_data): + category = Case(When(intnum__gte=8, then=F("intnum_null")), default="default") + sql = ( + IntFields.all() + .annotate(category=category) + .values("intnum", "category") + .sql(params_inline=True) + ) + + dialect = db.db().schema_generator.DIALECT + if dialect == "mysql": + expected_sql = "SELECT `intnum` `intnum`,CASE WHEN `intnum`>=8 THEN `intnum_null` ELSE 'default' END `category` FROM `intfields`" + else: + expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=8 THEN "intnum_null" ELSE \'default\' END "category" FROM "intfields"' + assert sql == expected_sql + + +@pytest.mark.asyncio +async def test_AE_then(db, intfields_data): + # AE: ArithmeticExpression + category = Case(When(intnum__gte=8, then=F("intnum") + 1), default="default") + sql = ( + IntFields.all() + .annotate(category=category) + .values("intnum", "category") + .sql(params_inline=True) + ) + + dialect = db.db().schema_generator.DIALECT + if dialect == "mysql": + expected_sql = "SELECT `intnum` `intnum`,CASE WHEN `intnum`>=8 THEN `intnum`+1 ELSE 'default' END `category` FROM `intfields`" + else: + expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=8 THEN "intnum"+1 ELSE \'default\' END "category" FROM "intfields"' + assert sql == expected_sql + + +@pytest.mark.asyncio +async def test_func_then(db, intfields_data): + category = Case(When(intnum__gte=8, then=Coalesce("intnum_null", 10)), default="default") + sql = ( + IntFields.all() + .annotate(category=category) + .values("intnum", "category") + .sql(params_inline=True) + ) + + dialect = db.db().schema_generator.DIALECT + if dialect == "mysql": + expected_sql = "SELECT `intnum` `intnum`,CASE WHEN `intnum`>=8 THEN COALESCE(`intnum_null`,10) ELSE 'default' END `category` FROM `intfields`" + else: + expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=8 THEN COALESCE("intnum_null",10) ELSE \'default\' END "category" FROM "intfields"' + assert sql == expected_sql + + +@pytest.mark.asyncio 
+async def test_F_default(db, intfields_data): + category = Case(When(intnum__gte=8, then="big"), default=F("intnum_null")) + sql = ( + IntFields.all() + .annotate(category=category) + .values("intnum", "category") + .sql(params_inline=True) + ) + + dialect = db.db().schema_generator.DIALECT + if dialect == "mysql": + expected_sql = "SELECT `intnum` `intnum`,CASE WHEN `intnum`>=8 THEN 'big' ELSE `intnum_null` END `category` FROM `intfields`" + else: + expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=8 THEN \'big\' ELSE "intnum_null" END "category" FROM "intfields"' + assert sql == expected_sql + + +@pytest.mark.asyncio +async def test_AE_default(db, intfields_data): + # AE: ArithmeticExpression + category = Case(When(intnum__gte=8, then=8), default=F("intnum") + 1) + sql = ( + IntFields.all() + .annotate(category=category) + .values("intnum", "category") + .sql(params_inline=True) + ) + + dialect = db.db().schema_generator.DIALECT + if dialect == "mysql": + expected_sql = "SELECT `intnum` `intnum`,CASE WHEN `intnum`>=8 THEN 8 ELSE `intnum`+1 END `category` FROM `intfields`" + else: + expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=8 THEN 8 ELSE "intnum"+1 END "category" FROM "intfields"' + assert sql == expected_sql + + +@pytest.mark.asyncio +async def test_func_default(db, intfields_data): + category = Case(When(intnum__gte=8, then=8), default=Coalesce("intnum_null", 10)) + sql = ( + IntFields.all() + .annotate(category=category) + .values("intnum", "category") + .sql(params_inline=True) + ) + + dialect = db.db().schema_generator.DIALECT + if dialect == "mysql": + expected_sql = "SELECT `intnum` `intnum`,CASE WHEN `intnum`>=8 THEN 8 ELSE COALESCE(`intnum_null`,10) END `category` FROM `intfields`" + else: + expected_sql = 'SELECT "intnum" "intnum",CASE WHEN "intnum">=8 THEN 8 ELSE COALESCE("intnum_null",10) END "category" FROM "intfields"' + assert sql == expected_sql + + +@pytest.mark.asyncio +async def test_case_when_in_where(db, 
intfields_data): + category = Case( + When(intnum__gte=8, then="big"), When(intnum__lte=2, then="small"), default="middle" + ) + sql = ( + IntFields.all() + .annotate(category=category) + .filter(category__in=["big", "small"]) + .values("intnum") + .sql(params_inline=True) + ) + dialect = db.db().schema_generator.DIALECT + if dialect == "mysql": + expected_sql = "SELECT `intnum` `intnum` FROM `intfields` WHERE CASE WHEN `intnum`>=8 THEN 'big' WHEN `intnum`<=2 THEN 'small' ELSE 'middle' END IN ('big','small')" + else: + expected_sql = "SELECT \"intnum\" \"intnum\" FROM \"intfields\" WHERE CASE WHEN \"intnum\">=8 THEN 'big' WHEN \"intnum\"<=2 THEN 'small' ELSE 'middle' END IN ('big','small')" + assert sql == expected_sql + + +@pytest.mark.asyncio +async def test_annotation_in_when_annotation(db, intfields_data): + sql = ( + IntFields.all() + .annotate(intnum_plus_1=F("intnum") + 1) + .annotate(bigger_than_10=Case(When(Q(intnum_plus_1__gte=10), then=True), default=False)) + .values("id", "intnum", "intnum_plus_1", "bigger_than_10") + .sql(params_inline=True) + ) + + dialect = db.db().schema_generator.DIALECT + if dialect == "mysql": + expected_sql = "SELECT `id` `id`,`intnum` `intnum`,`intnum`+1 `intnum_plus_1`,CASE WHEN `intnum`+1>=10 THEN true ELSE false END `bigger_than_10` FROM `intfields`" + else: + expected_sql = 'SELECT "id" "id","intnum" "intnum","intnum"+1 "intnum_plus_1",CASE WHEN "intnum"+1>=10 THEN true ELSE false END "bigger_than_10" FROM "intfields"' + assert sql == expected_sql + + +@pytest.mark.asyncio +async def test_func_annotation_in_when_annotation(db, intfields_data): + sql = ( + IntFields.all() + .annotate(intnum_col=Coalesce("intnum", 0)) + .annotate(is_zero=Case(When(Q(intnum_col=0), then=True), default=False)) + .values("id", "intnum_col", "is_zero") + .sql(params_inline=True) + ) + + dialect = db.db().schema_generator.DIALECT + if dialect == "mysql": + expected_sql = "SELECT `id` `id`,COALESCE(`intnum`,0) `intnum_col`,CASE WHEN 
COALESCE(`intnum`,0)=0 THEN true ELSE false END `is_zero` FROM `intfields`" + else: + expected_sql = 'SELECT "id" "id",COALESCE("intnum",0) "intnum_col",CASE WHEN COALESCE("intnum",0)=0 THEN true ELSE false END "is_zero" FROM "intfields"' + assert sql == expected_sql + + +@pytest.mark.asyncio +async def test_case_when_in_group_by(db, intfields_data): + sql = ( + IntFields.all() + .annotate(is_zero=Case(When(Q(intnum=0), then=True), default=False)) + .annotate(count=Count("id")) + .group_by("is_zero") + .values("is_zero", "count") + .sql(params_inline=True) + ) + + dialect = db.db().schema_generator.DIALECT + if dialect == "mysql": + expected_sql = "SELECT CASE WHEN `intnum`=0 THEN true ELSE false END `is_zero`,COUNT(`id`) `count` FROM `intfields` GROUP BY `is_zero`" + elif dialect == "mssql": + expected_sql = 'SELECT CASE WHEN "intnum"=0 THEN true ELSE false END "is_zero",COUNT("id") "count" FROM "intfields" GROUP BY CASE WHEN "intnum"=0 THEN true ELSE false END' + else: + expected_sql = 'SELECT CASE WHEN "intnum"=0 THEN true ELSE false END "is_zero",COUNT("id") "count" FROM "intfields" GROUP BY "is_zero"' + assert sql == expected_sql + + +@pytest.mark.asyncio +async def test_unknown_field_in_when_annotation(db, intfields_data): + with pytest.raises(FieldError, match="Unknown filter param 'unknown'.+"): + IntFields.all().annotate(intnum_col=Coalesce("intnum", 0)).annotate( + is_zero=Case(When(Q(unknown=0), then="1"), default="2") + ).sql(params_inline=True) diff --git a/tests/test_concurrency.py b/tests/test_concurrency.py index cc565b677..f78d694b3 100644 --- a/tests/test_concurrency.py +++ b/tests/test_concurrency.py @@ -1,115 +1,172 @@ import asyncio import sys +import pytest + from tests.testmodels import Tournament, UniqueName -from tortoise import Tortoise, connections -from tortoise.contrib import test +from tortoise import connections +from tortoise.contrib.test import requireCapability from tortoise.contrib.test.condition import NotEQ from 
tortoise.transactions import in_transaction - -class TestConcurrencyIsolated(test.IsolatedTestCase): - async def test_concurrency_read(self): - await Tournament.create(name="Test") - tour1 = await Tournament.first() - all_read = await asyncio.gather(*[Tournament.first() for _ in range(100)]) - self.assertEqual(all_read, [tour1 for _ in range(100)]) - - async def test_concurrency_create(self): - all_write = await asyncio.gather(*[Tournament.create(name="Test") for _ in range(100)]) - all_read = await Tournament.all() - self.assertEqual(set(all_write), set(all_read)) - - async def test_nonconcurrent_get_or_create(self): - unas = [await UniqueName.get_or_create(name="c") for _ in range(10)] - una_created = [una[1] for una in unas if una[1] is True] - self.assertEqual(len(una_created), 1) - for una in unas: - self.assertEqual(una[0], unas[0][0]) - - @test.skipIf(sys.version_info < (3, 7), "aiocontextvars backport not handling this well") - @test.requireCapability(dialect=NotEQ("mssql")) - async def test_concurrent_get_or_create(self): - unas = await asyncio.gather(*[UniqueName.get_or_create(name="d") for _ in range(10)]) - una_created = [una[1] for una in unas if una[1] is True] - self.assertEqual(len(una_created), 1) - for una in unas: - self.assertEqual(una[0], unas[0][0]) - - @test.skipIf(sys.version_info < (3, 7), "aiocontextvars backport not handling this well") - @test.requireCapability(supports_transactions=True) - async def test_concurrent_transactions_with_multiple_ops(self): - async def create_in_transaction(): +# ============================================================================= +# TestConcurrencyIsolated - uses db_isolated fixture +# ============================================================================= + + +@pytest.mark.asyncio +async def test_concurrency_read_isolated(db_isolated): + """Test concurrent reads.""" + await Tournament.create(name="Test") + tour1 = await Tournament.first() + all_read = await 
asyncio.gather(*[Tournament.first() for _ in range(100)]) + assert all_read == [tour1 for _ in range(100)] + + +@pytest.mark.asyncio +async def test_concurrency_create_isolated(db_isolated): + """Test concurrent creates.""" + all_write = await asyncio.gather(*[Tournament.create(name="Test") for _ in range(100)]) + all_read = await Tournament.all() + assert set(all_write) == set(all_read) + + +@pytest.mark.asyncio +async def test_nonconcurrent_get_or_create_isolated(db_isolated): + """Test non-concurrent get_or_create.""" + unas = [await UniqueName.get_or_create(name="c") for _ in range(10)] + una_created = [una[1] for una in unas if una[1] is True] + assert len(una_created) == 1 + for una in unas: + assert una[0] == unas[0][0] + + +@pytest.mark.skipif( + sys.version_info < (3, 7), reason="aiocontextvars backport not handling this well" +) +@requireCapability(dialect=NotEQ("mssql")) +@pytest.mark.asyncio +async def test_concurrent_get_or_create_isolated(db_isolated): + """Test concurrent get_or_create.""" + unas = await asyncio.gather(*[UniqueName.get_or_create(name="d") for _ in range(10)]) + una_created = [una[1] for una in unas if una[1] is True] + assert len(una_created) == 1 + for una in unas: + assert una[0] == unas[0][0] + + +@pytest.mark.skipif( + sys.version_info < (3, 7), reason="aiocontextvars backport not handling this well" +) +@requireCapability(supports_transactions=True) +@pytest.mark.asyncio +async def test_concurrent_transactions_with_multiple_ops(db_isolated): + """Test concurrent transactions with multiple operations.""" + + async def create_in_transaction(): + async with in_transaction(): + await asyncio.gather(*[Tournament.create(name="Test") for _ in range(100)]) + + await asyncio.gather(*[create_in_transaction() for _ in range(10)]) + count = await Tournament.all().count() + assert count == 1000 + + +@pytest.mark.skipif( + sys.version_info < (3, 7), reason="aiocontextvars backport not handling this well" +) 
+@requireCapability(supports_transactions=True) +@pytest.mark.asyncio +async def test_concurrent_transactions_with_single_op(db_isolated): + """Test concurrent transactions with single operation.""" + + async def create(): + async with in_transaction(): + await Tournament.create(name="Test") + + await asyncio.gather(*[create() for _ in range(100)]) + count = await Tournament.all().count() + assert count == 100 + + +@pytest.mark.skipif( + sys.version_info < (3, 7), reason="aiocontextvars backport not handling this well" +) +@requireCapability(supports_transactions=True) +@pytest.mark.asyncio +async def test_nested_concurrent_transactions_with_multiple_ops(db_isolated): + """Test nested concurrent transactions with multiple operations.""" + + async def create_in_transaction(): + async with in_transaction(): async with in_transaction(): await asyncio.gather(*[Tournament.create(name="Test") for _ in range(100)]) - await asyncio.gather(*[create_in_transaction() for _ in range(10)]) - count = await Tournament.all().count() - self.assertEqual(count, 1000) + await asyncio.gather(*[create_in_transaction() for _ in range(10)]) + count = await Tournament.all().count() + assert count == 1000 - @test.skipIf(sys.version_info < (3, 7), "aiocontextvars backport not handling this well") - @test.requireCapability(supports_transactions=True) - async def test_concurrent_transactions_with_single_op(self): - async def create(): - async with in_transaction(): - await Tournament.create(name="Test") - await asyncio.gather(*[create() for _ in range(100)]) - count = await Tournament.all().count() - self.assertEqual(count, 100) +# ============================================================================= +# TestConcurrencyTransactioned - uses db fixture (transaction rollback) +# ============================================================================= - @test.skipIf(sys.version_info < (3, 7), "aiocontextvars backport not handling this well") - 
@test.requireCapability(supports_transactions=True) - async def test_nested_concurrent_transactions_with_multiple_ops(self): - async def create_in_transaction(): - async with in_transaction(): - async with in_transaction(): - await asyncio.gather(*[Tournament.create(name="Test") for _ in range(100)]) - - await asyncio.gather(*[create_in_transaction() for _ in range(10)]) - count = await Tournament.all().count() - self.assertEqual(count, 1000) - - -@test.requireCapability(supports_transactions=True) -class TestConcurrencyTransactioned(test.TestCase): - async def test_concurrency_read(self): - await Tournament.create(name="Test") - tour1 = await Tournament.first() - all_read = await asyncio.gather(*[Tournament.first() for _ in range(100)]) - self.assertEqual(all_read, [tour1 for _ in range(100)]) - - async def test_concurrency_create(self): - all_write = await asyncio.gather(*[Tournament.create(name="Test") for _ in range(100)]) - all_read = await Tournament.all() - self.assertEqual(set(all_write), set(all_read)) - - async def test_nonconcurrent_get_or_create(self): - unas = [await UniqueName.get_or_create(name="a") for _ in range(10)] - una_created = [una[1] for una in unas if una[1] is True] - self.assertEqual(len(una_created), 1) - for una in unas: - self.assertEqual(una[0], unas[0][0]) - - -class TestConcurrentDBConnectionInitialization(test.IsolatedTestCase): - """Tortoise.init is lazy and does not initialize the database connection until the first query. 
- These tests ensure that concurrent queries do not cause initialization issues.""" - - async def _setUpDB(self) -> None: - """Override to avoid database connection initialization when generating the schema.""" - await super()._setUpDB() - config = test.getDBConfig(app_label="models", modules=test._MODULES) - await Tortoise.init(config, _create_db=True) - - async def test_concurrent_queries(self): - await asyncio.gather( - *[connections.get("models").execute_query("SELECT 1") for _ in range(100)] - ) - - async def test_concurrent_transactions(self) -> None: - async def transaction() -> None: - async with in_transaction(): - await connections.get("models").execute_query("SELECT 1") - await asyncio.gather(*[transaction() for _ in range(100)]) +@requireCapability(supports_transactions=True) +@pytest.mark.asyncio +async def test_concurrency_read_transactioned(db): + """Test concurrent reads within transaction.""" + await Tournament.create(name="Test") + tour1 = await Tournament.first() + all_read = await asyncio.gather(*[Tournament.first() for _ in range(100)]) + assert all_read == [tour1 for _ in range(100)] + + +@requireCapability(supports_transactions=True) +@pytest.mark.asyncio +async def test_concurrency_create_transactioned(db): + """Test concurrent creates within transaction.""" + all_write = await asyncio.gather(*[Tournament.create(name="Test") for _ in range(100)]) + all_read = await Tournament.all() + assert set(all_write) == set(all_read) + + +@requireCapability(supports_transactions=True) +@pytest.mark.asyncio +async def test_nonconcurrent_get_or_create_transactioned(db): + """Test non-concurrent get_or_create within transaction.""" + unas = [await UniqueName.get_or_create(name="a") for _ in range(10)] + una_created = [una[1] for una in unas if una[1] is True] + assert len(una_created) == 1 + for una in unas: + assert una[0] == unas[0][0] + + +# ============================================================================= +# 
TestConcurrentDBConnectionInitialization - tests lazy connection init +# These tests ensure concurrent queries don't cause initialization issues. +# ============================================================================= + + +@pytest.mark.asyncio +async def test_concurrent_queries_lazy_init(db_isolated): + """Test concurrent queries with lazy connection initialization. + + Tortoise.init is lazy and does not initialize the database connection + until the first query. This test ensures that concurrent queries do not + cause initialization issues. + """ + # The db_isolated fixture already initializes the connection, so we just + # test that concurrent queries work + await asyncio.gather(*[connections.get("models").execute_query("SELECT 1") for _ in range(100)]) + + +@pytest.mark.asyncio +async def test_concurrent_transactions_lazy_init(db_isolated): + """Test concurrent transactions with lazy connection initialization.""" + + async def transaction() -> None: + async with in_transaction(): + await connections.get("models").execute_query("SELECT 1") + + await asyncio.gather(*[transaction() for _ in range(100)]) diff --git a/tests/test_connection.py b/tests/test_connection.py index 75e44a151..9a48c76bb 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -1,287 +1,284 @@ -from contextvars import ContextVar from unittest.mock import AsyncMock, Mock, PropertyMock, call, patch +import pytest + from tortoise import BaseDBAsyncClient, ConfigurationError from tortoise.connection import ConnectionHandler -from tortoise.contrib.test import SimpleTestCase - - -class TestConnections(SimpleTestCase): - def setUp(self) -> None: - self.conn_handler = ConnectionHandler() - - def test_init_constructor(self): - self.assertIsNone(self.conn_handler._db_config) - self.assertFalse(self.conn_handler._create_db) - - @patch("tortoise.connection.ConnectionHandler._init_connections") - async def test_init(self, mocked_init_connections: AsyncMock): - db_config = 
{"default": {"HOST": "some_host", "PORT": "1234"}} - await self.conn_handler._init(db_config, True) - mocked_init_connections.assert_awaited_once() - self.assertEqual(db_config, self.conn_handler._db_config) - self.assertTrue(self.conn_handler._create_db) - - def test_db_config_present(self): - self.conn_handler._db_config = {"default": {"HOST": "some_host", "PORT": "1234"}} - self.assertEqual(self.conn_handler.db_config, self.conn_handler._db_config) - - def test_db_config_not_present(self): - err_msg = ( - "DB configuration not initialised. Make sure to call " - "Tortoise.init with a valid configuration before attempting " - "to create connections." - ) - with self.assertRaises(ConfigurationError, msg=err_msg): - _ = self.conn_handler.db_config - - @patch("tortoise.connection.ConnectionHandler._conn_storage", spec=ContextVar) - def test_get_storage(self, mocked_conn_storage: Mock): - expected_ret_val = {"default": BaseDBAsyncClient("default")} - mocked_conn_storage.get.return_value = expected_ret_val - ret_val = self.conn_handler._get_storage() - self.assertDictEqual(ret_val, expected_ret_val) - - @patch("tortoise.connection.ConnectionHandler._conn_storage", spec=ContextVar) - def test_set_storage(self, mocked_conn_storage: Mock): - mocked_conn_storage.set.return_value = "blah" - new_storage = {"default": BaseDBAsyncClient("default")} - ret_val = self.conn_handler._set_storage(new_storage) - mocked_conn_storage.set.assert_called_once_with(new_storage) - self.assertEqual(ret_val, mocked_conn_storage.set.return_value) - - @patch("tortoise.connection.ConnectionHandler._get_storage") - @patch("tortoise.connection.copy") - def test_copy_storage(self, mocked_copy: Mock, mocked_get_storage: Mock): - expected_ret_value = {"default": BaseDBAsyncClient("default")} - mocked_get_storage.return_value = expected_ret_value - mocked_copy.return_value = expected_ret_value.copy() - ret_val = self.conn_handler._copy_storage() - mocked_get_storage.assert_called_once() - 
mocked_copy.assert_called_once_with(mocked_get_storage.return_value) - self.assertDictEqual(ret_val, expected_ret_value) - self.assertNotEqual(id(expected_ret_value), id(ret_val)) - - @patch("tortoise.connection.ConnectionHandler._get_storage") - def test_clear_storage(self, mocked_get_storage: Mock): - self.conn_handler._clear_storage() - mocked_get_storage.assert_called_once() - mocked_get_storage.return_value.clear.assert_called_once() - - @patch("tortoise.connection.importlib.import_module") - def test_discover_client_class_proper_impl(self, mocked_import_module: Mock): - mocked_import_module.return_value = Mock(client_class="some_class") - del mocked_import_module.return_value.get_client_class - client_class = self.conn_handler._discover_client_class({"engine": "blah"}) - - mocked_import_module.assert_called_once_with("blah") - self.assertEqual(client_class, "some_class") - - @patch("tortoise.connection.importlib.import_module") - def test_discover_client_class_improper_impl(self, mocked_import_module: Mock): - del mocked_import_module.return_value.client_class - del mocked_import_module.return_value.get_client_class - engine = "some_engine" - with self.assertRaises( - ConfigurationError, msg=f'Backend for engine "{engine}" does not implement db client' - ): - _ = self.conn_handler._discover_client_class({"engine": engine}) - - @patch("tortoise.connection.ConnectionHandler.db_config", new_callable=PropertyMock) - def test_get_db_info_present(self, mocked_db_config: Mock): - expected_ret_val = {"HOST": "some_host", "PORT": "1234"} - mocked_db_config.return_value = {"default": expected_ret_val} - ret_val = self.conn_handler._get_db_info("default") - self.assertEqual(ret_val, expected_ret_val) - - @patch("tortoise.connection.ConnectionHandler.db_config", new_callable=PropertyMock) - def test_get_db_info_not_present(self, mocked_db_config: Mock): - mocked_db_config.return_value = {"default": {"HOST": "some_host", "PORT": "1234"}} - conn_alias = "blah" - with 
self.assertRaises( - ConfigurationError, - msg=f"Unable to get db settings for alias '{conn_alias}'. Please " - f"check if the config dict contains this alias and try again", - ): - _ = self.conn_handler._get_db_info(conn_alias) - - @patch("tortoise.connection.ConnectionHandler.db_config", new_callable=PropertyMock) - @patch("tortoise.connection.ConnectionHandler.get") - async def test_init_connections_no_db_create(self, mocked_get: Mock, mocked_db_config: Mock): - conn_1, conn_2 = AsyncMock(spec=BaseDBAsyncClient), AsyncMock(spec=BaseDBAsyncClient) - mocked_get.side_effect = [conn_1, conn_2] - mocked_db_config.return_value = { - "default": {"HOST": "some_host", "PORT": "1234"}, - "other": {"HOST": "some_other_host", "PORT": "1234"}, - } - await self.conn_handler._init_connections() - mocked_db_config.assert_called_once() - mocked_get.assert_has_calls([call("default"), call("other")], any_order=True) - conn_1.db_create.assert_not_awaited() - conn_2.db_create.assert_not_awaited() - - @patch("tortoise.connection.ConnectionHandler.db_config", new_callable=PropertyMock) - @patch("tortoise.connection.ConnectionHandler.get") - async def test_init_connections_db_create(self, mocked_get: Mock, mocked_db_config: Mock): - self.conn_handler._create_db = True - conn_1, conn_2 = AsyncMock(spec=BaseDBAsyncClient), AsyncMock(spec=BaseDBAsyncClient) - mocked_get.side_effect = [conn_1, conn_2] - mocked_db_config.return_value = { - "default": {"HOST": "some_host", "PORT": "1234"}, - "other": {"HOST": "some_other_host", "PORT": "1234"}, - } - await self.conn_handler._init_connections() - mocked_db_config.assert_called_once() - mocked_get.assert_has_calls([call("default"), call("other")], any_order=True) - conn_1.db_create.assert_awaited_once() - conn_2.db_create.assert_awaited_once() - - @patch("tortoise.connection.ConnectionHandler._get_db_info") - @patch("tortoise.connection.expand_db_url") - @patch("tortoise.connection.ConnectionHandler._discover_client_class") - def 
test_create_connection_db_info_str( - self, - mocked_discover_client_class: Mock, - mocked_expand_db_url: Mock, - mocked_get_db_info: Mock, - ): - alias = "default" - mocked_get_db_info.return_value = "some_db_url" - mocked_expand_db_url.return_value = { - "engine": "some_engine", - "credentials": {"cred_key": "some_val"}, - } - expected_client_class = Mock(return_value="some_connection") - mocked_discover_client_class.return_value = expected_client_class - expected_db_params = {"cred_key": "some_val", "connection_name": alias} - - ret_val = self.conn_handler._create_connection(alias) - - mocked_get_db_info.assert_called_once_with(alias) - mocked_expand_db_url.assert_called_once_with("some_db_url") - mocked_discover_client_class.assert_called_once_with( - {"engine": "some_engine", "credentials": {"cred_key": "some_val"}} - ) - expected_client_class.assert_called_once_with(**expected_db_params) - self.assertEqual(ret_val, "some_connection") - - @patch("tortoise.connection.ConnectionHandler._get_db_info") - @patch("tortoise.connection.expand_db_url") - @patch("tortoise.connection.ConnectionHandler._discover_client_class") - def test_create_connection_db_info_not_str( - self, - mocked_discover_client_class: Mock, - mocked_expand_db_url: Mock, - mocked_get_db_info: Mock, + + +@pytest.fixture +def conn_handler(): + return ConnectionHandler() + + +def test_init_constructor(conn_handler): + assert conn_handler._db_config is None + assert conn_handler._create_db is False + assert conn_handler._storage == {} + + +@pytest.mark.asyncio +@patch("tortoise.connection.ConnectionHandler._init_connections") +async def test_init(mocked_init_connections, conn_handler): + db_config = {"default": {"HOST": "some_host", "PORT": "1234"}} + await conn_handler._init(db_config, True) + mocked_init_connections.assert_awaited_once() + assert db_config == conn_handler._db_config + assert conn_handler._create_db is True + + +def test_db_config_present(conn_handler): + conn_handler._db_config = 
{"default": {"HOST": "some_host", "PORT": "1234"}} + assert conn_handler.db_config == conn_handler._db_config + + +def test_db_config_not_present(conn_handler): + err_msg = ( + "DB configuration not initialised. Make sure to call " + "Tortoise.init with a valid configuration before attempting " + "to create connections." + ) + with pytest.raises(ConfigurationError, match=err_msg): + _ = conn_handler.db_config + + +def test_get_storage(conn_handler): + expected_ret_val = {"default": BaseDBAsyncClient("default")} + conn_handler._storage = expected_ret_val + ret_val = conn_handler._get_storage() + assert ret_val == expected_ret_val + assert ret_val is conn_handler._storage + + +def test_set_storage(conn_handler): + new_storage = {"default": BaseDBAsyncClient("default")} + conn_handler._set_storage(new_storage) + assert conn_handler._storage == new_storage + assert conn_handler._storage is new_storage + + +def test_copy_storage(conn_handler): + original_storage = {"default": BaseDBAsyncClient("default")} + conn_handler._storage = original_storage + ret_val = conn_handler._copy_storage() + assert ret_val == original_storage + assert ret_val is not original_storage + + +def test_clear_storage(conn_handler): + conn_handler._storage = {"default": BaseDBAsyncClient("default")} + conn_handler._clear_storage() + assert conn_handler._storage == {} + + +@patch("tortoise.connection.importlib.import_module") +def test_discover_client_class_proper_impl(mocked_import_module, conn_handler): + mocked_import_module.return_value = Mock(client_class="some_class") + del mocked_import_module.return_value.get_client_class + client_class = conn_handler._discover_client_class({"engine": "blah"}) + + mocked_import_module.assert_called_once_with("blah") + assert client_class == "some_class" + + +@patch("tortoise.connection.importlib.import_module") +def test_discover_client_class_improper_impl(mocked_import_module, conn_handler): + del mocked_import_module.return_value.client_class + del 
mocked_import_module.return_value.get_client_class + engine = "some_engine" + with pytest.raises( + ConfigurationError, match=f'Backend for engine "{engine}" does not implement db client' ): - alias = "default" - mocked_get_db_info.return_value = { - "engine": "some_engine", - "credentials": {"cred_key": "some_val"}, - } - expected_client_class = Mock(return_value="some_connection") - mocked_discover_client_class.return_value = expected_client_class - expected_db_params = {"cred_key": "some_val", "connection_name": alias} - - ret_val = self.conn_handler._create_connection(alias) - - mocked_get_db_info.assert_called_once_with(alias) - mocked_expand_db_url.assert_not_called() - mocked_discover_client_class.assert_called_once_with( - {"engine": "some_engine", "credentials": {"cred_key": "some_val"}} - ) - expected_client_class.assert_called_once_with(**expected_db_params) - self.assertEqual(ret_val, "some_connection") - - @patch("tortoise.connection.ConnectionHandler._get_storage") - @patch("tortoise.connection.ConnectionHandler._create_connection") - def test_get_alias_present(self, mocked_create_connection: Mock, mocked_get_storage: Mock): - mocked_get_storage.return_value = {"default": "some_connection"} - ret_val = self.conn_handler.get("default") - mocked_get_storage.assert_called_once() - mocked_create_connection.assert_not_called() - self.assertEqual(ret_val, "some_connection") - - @patch("tortoise.connection.ConnectionHandler._get_storage") - @patch("tortoise.connection.ConnectionHandler._create_connection") - def test_get_alias_not_present(self, mocked_create_connection: Mock, mocked_get_storage: Mock): - mocked_get_storage.return_value = {"default": "some_connection"} - expected_final_dict = {**mocked_get_storage.return_value, "other": "some_other_connection"} - mocked_create_connection.return_value = "some_other_connection" - ret_val = self.conn_handler.get("other") - mocked_get_storage.assert_called_once() - 
mocked_create_connection.assert_called_once_with("other") - self.assertEqual(ret_val, "some_other_connection") - self.assertDictEqual(mocked_get_storage.return_value, expected_final_dict) - - @patch("tortoise.connection.ConnectionHandler._conn_storage", spec=ContextVar) - @patch("tortoise.connection.ConnectionHandler._copy_storage") - def test_set(self, mocked_copy_storage: Mock, mocked_conn_storage: Mock): - mocked_copy_storage.return_value = {} - expected_storage = {"default": "some_conn"} - mocked_conn_storage.set.return_value = "some_token" - ret_val = self.conn_handler.set("default", "some_conn") # type: ignore - mocked_copy_storage.assert_called_once() - self.assertEqual(ret_val, mocked_conn_storage.set.return_value) - self.assertDictEqual(expected_storage, mocked_copy_storage.return_value) - - @patch("tortoise.connection.ConnectionHandler._get_storage") - def test_discard(self, mocked_get_storage: Mock): - mocked_get_storage.return_value = {"default": "some_conn"} - ret_val = self.conn_handler.discard("default") - self.assertEqual(ret_val, "some_conn") - self.assertDictEqual({}, mocked_get_storage.return_value) - - @patch("tortoise.connection.ConnectionHandler._conn_storage", spec=ContextVar) - @patch("tortoise.connection.ConnectionHandler._get_storage") - def test_reset(self, mocked_get_storage: Mock, mocked_conn_storage: Mock): - first_config = {"other": "some_other_conn", "default": "diff_conn"} - second_config = {"default": "some_conn"} - mocked_get_storage.side_effect = [first_config, second_config] - final_storage = {"default": "some_conn", "other": "some_other_conn"} - self.conn_handler.reset("some_token") # type: ignore - mocked_get_storage.assert_has_calls([call(), call()]) - mocked_conn_storage.reset.assert_called_once_with("some_token") - self.assertDictEqual(final_storage, second_config) - - @patch("tortoise.connection.ConnectionHandler.db_config", new_callable=PropertyMock) - @patch("tortoise.connection.ConnectionHandler.get") - def 
test_all(self, mocked_get: Mock, mocked_db_config: Mock): - db_config = {"default": "some_conn", "other": "some_other_conn"} - - def side_effect_callable(alias): - return db_config[alias] - - mocked_get.side_effect = side_effect_callable - mocked_db_config.return_value = db_config - expected_result = ["some_conn", "some_other_conn"] - ret_val = self.conn_handler.all() - mocked_db_config.assert_called_once() - mocked_get.assert_has_calls([call("default"), call("other")], any_order=True) - self.assertEqual(ret_val, expected_result) - - @patch("tortoise.connection.ConnectionHandler.all") - @patch("tortoise.connection.ConnectionHandler.discard") - @patch("tortoise.connection.ConnectionHandler.db_config", new_callable=PropertyMock) - async def test_close_all_with_discard( - self, mocked_db_config: Mock, mocked_discard: Mock, mocked_all: Mock + _ = conn_handler._discover_client_class({"engine": engine}) + + +@patch("tortoise.connection.ConnectionHandler.db_config", new_callable=PropertyMock) +def test_get_db_info_present(mocked_db_config, conn_handler): + expected_ret_val = {"HOST": "some_host", "PORT": "1234"} + mocked_db_config.return_value = {"default": expected_ret_val} + ret_val = conn_handler._get_db_info("default") + assert ret_val == expected_ret_val + + +@patch("tortoise.connection.ConnectionHandler.db_config", new_callable=PropertyMock) +def test_get_db_info_not_present(mocked_db_config, conn_handler): + mocked_db_config.return_value = {"default": {"HOST": "some_host", "PORT": "1234"}} + conn_alias = "blah" + with pytest.raises( + ConfigurationError, + match=f"Unable to get db settings for alias '{conn_alias}'", ): - all_conn = [AsyncMock(spec=BaseDBAsyncClient), AsyncMock(spec=BaseDBAsyncClient)] - db_config = {"default": "some_config", "other": "some_other_config"} - mocked_all.return_value = all_conn - mocked_db_config.return_value = db_config - await self.conn_handler.close_all() - mocked_all.assert_called_once() - mocked_db_config.assert_called_once() - 
for mock_obj in all_conn: - mock_obj.close.assert_awaited_once() - mocked_discard.assert_has_calls([call("default"), call("other")], any_order=True) - - @patch("tortoise.connection.ConnectionHandler.all") - async def test_close_all_without_discard(self, mocked_all: Mock): - all_conn = [AsyncMock(spec=BaseDBAsyncClient), AsyncMock(spec=BaseDBAsyncClient)] - mocked_all.return_value = all_conn - await self.conn_handler.close_all(discard=False) - mocked_all.assert_called_once() - for mock_obj in all_conn: - mock_obj.close.assert_awaited_once() + _ = conn_handler._get_db_info(conn_alias) + + +@pytest.mark.asyncio +@patch("tortoise.connection.ConnectionHandler.db_config", new_callable=PropertyMock) +@patch("tortoise.connection.ConnectionHandler.get") +async def test_init_connections_no_db_create(mocked_get, mocked_db_config, conn_handler): + conn_1, conn_2 = AsyncMock(spec=BaseDBAsyncClient), AsyncMock(spec=BaseDBAsyncClient) + mocked_get.side_effect = [conn_1, conn_2] + mocked_db_config.return_value = { + "default": {"HOST": "some_host", "PORT": "1234"}, + "other": {"HOST": "some_other_host", "PORT": "1234"}, + } + await conn_handler._init_connections() + mocked_db_config.assert_called_once() + mocked_get.assert_has_calls([call("default"), call("other")], any_order=True) + conn_1.db_create.assert_not_awaited() + conn_2.db_create.assert_not_awaited() + + +@pytest.mark.asyncio +@patch("tortoise.connection.ConnectionHandler.db_config", new_callable=PropertyMock) +@patch("tortoise.connection.ConnectionHandler.get") +async def test_init_connections_db_create(mocked_get, mocked_db_config, conn_handler): + conn_handler._create_db = True + conn_1, conn_2 = AsyncMock(spec=BaseDBAsyncClient), AsyncMock(spec=BaseDBAsyncClient) + mocked_get.side_effect = [conn_1, conn_2] + mocked_db_config.return_value = { + "default": {"HOST": "some_host", "PORT": "1234"}, + "other": {"HOST": "some_other_host", "PORT": "1234"}, + } + await conn_handler._init_connections() + 
mocked_db_config.assert_called_once() + mocked_get.assert_has_calls([call("default"), call("other")], any_order=True) + conn_1.db_create.assert_awaited_once() + conn_2.db_create.assert_awaited_once() + + +@patch("tortoise.connection.ConnectionHandler._get_db_info") +@patch("tortoise.connection.expand_db_url") +@patch("tortoise.connection.ConnectionHandler._discover_client_class") +def test_create_connection_db_info_str( + mocked_discover_client_class, + mocked_expand_db_url, + mocked_get_db_info, + conn_handler, +): + alias = "default" + mocked_get_db_info.return_value = "some_db_url" + mocked_expand_db_url.return_value = { + "engine": "some_engine", + "credentials": {"cred_key": "some_val"}, + } + expected_client_class = Mock(return_value="some_connection") + mocked_discover_client_class.return_value = expected_client_class + expected_db_params = {"cred_key": "some_val", "connection_name": alias} + + ret_val = conn_handler._create_connection(alias) + + mocked_get_db_info.assert_called_once_with(alias) + mocked_expand_db_url.assert_called_once_with("some_db_url") + mocked_discover_client_class.assert_called_once_with( + {"engine": "some_engine", "credentials": {"cred_key": "some_val"}} + ) + expected_client_class.assert_called_once_with(**expected_db_params) + assert ret_val == "some_connection" + + +@patch("tortoise.connection.ConnectionHandler._get_db_info") +@patch("tortoise.connection.expand_db_url") +@patch("tortoise.connection.ConnectionHandler._discover_client_class") +def test_create_connection_db_info_not_str( + mocked_discover_client_class, + mocked_expand_db_url, + mocked_get_db_info, + conn_handler, +): + alias = "default" + mocked_get_db_info.return_value = { + "engine": "some_engine", + "credentials": {"cred_key": "some_val"}, + } + expected_client_class = Mock(return_value="some_connection") + mocked_discover_client_class.return_value = expected_client_class + expected_db_params = {"cred_key": "some_val", "connection_name": alias} + + ret_val = 
conn_handler._create_connection(alias) + + mocked_get_db_info.assert_called_once_with(alias) + mocked_expand_db_url.assert_not_called() + mocked_discover_client_class.assert_called_once_with( + {"engine": "some_engine", "credentials": {"cred_key": "some_val"}} + ) + expected_client_class.assert_called_once_with(**expected_db_params) + assert ret_val == "some_connection" + + +def test_get_alias_present(conn_handler): + conn_handler._storage = {"default": "some_connection"} + ret_val = conn_handler.get("default") + assert ret_val == "some_connection" + + +@patch("tortoise.connection.ConnectionHandler._create_connection") +def test_get_alias_not_present(mocked_create_connection, conn_handler): + conn_handler._storage = {"default": "some_connection"} + mocked_create_connection.return_value = "some_other_connection" + ret_val = conn_handler.get("other") + mocked_create_connection.assert_called_once_with("other") + assert ret_val == "some_other_connection" + assert conn_handler._storage == {"default": "some_connection", "other": "some_other_connection"} + + +def test_set(conn_handler): + conn_handler._storage = {"default": "existing_conn"} + token = conn_handler.set("other", "some_conn") + assert conn_handler._storage == {"default": "existing_conn", "other": "some_conn"} + assert token is not None + + +def test_discard(conn_handler): + conn_handler._storage = {"default": "some_conn", "other": "other_conn"} + ret_val = conn_handler.discard("default") + assert ret_val == "some_conn" + assert conn_handler._storage == {"other": "other_conn"} + + +def test_reset(conn_handler): + conn_handler._storage = {"default": "modified_conn", "other": "other_conn"} + original_conn = Mock() + token = Mock(_handler=conn_handler, _alias="default", _old_value=original_conn, _used=False) + conn_handler.reset(token) + assert conn_handler._storage["default"] is original_conn + assert token._used is True + + +@patch("tortoise.connection.ConnectionHandler.db_config", new_callable=PropertyMock) 
+def test_all(mocked_db_config, conn_handler): + conn_handler._storage = {"default": "some_conn", "other": "some_other_conn"} + mocked_db_config.return_value = {"default": {}, "other": {}} + ret_val = conn_handler.all() + assert set(ret_val) == {"some_conn", "some_other_conn"} + + +@pytest.mark.asyncio +@patch("tortoise.connection.ConnectionHandler.db_config", new_callable=PropertyMock) +async def test_close_all_with_discard(mocked_db_config, conn_handler): + conn_1, conn_2 = AsyncMock(spec=BaseDBAsyncClient), AsyncMock(spec=BaseDBAsyncClient) + conn_handler._storage = {"default": conn_1, "other": conn_2} + conn_handler._db_config = { + "default": {}, + "other": {}, + } # Set _db_config so close_all doesn't early-return + mocked_db_config.return_value = {"default": {}, "other": {}} + await conn_handler.close_all() + conn_1.close.assert_awaited_once() + conn_2.close.assert_awaited_once() + assert conn_handler._storage == {} + + +@pytest.mark.asyncio +@patch("tortoise.connection.ConnectionHandler.db_config", new_callable=PropertyMock) +async def test_close_all_without_discard(mocked_db_config, conn_handler): + conn_1, conn_2 = AsyncMock(spec=BaseDBAsyncClient), AsyncMock(spec=BaseDBAsyncClient) + conn_handler._storage = {"default": conn_1, "other": conn_2} + conn_handler._db_config = { + "default": {}, + "other": {}, + } # Set _db_config so close_all doesn't early-return + mocked_db_config.return_value = {"default": {}, "other": {}} + await conn_handler.close_all(discard=False) + conn_1.close.assert_awaited_once() + conn_2.close.assert_awaited_once() + assert conn_handler._storage == {"default": conn_1, "other": conn_2} diff --git a/tests/test_context.py b/tests/test_context.py new file mode 100644 index 000000000..c9afe3e16 --- /dev/null +++ b/tests/test_context.py @@ -0,0 +1,552 @@ +""" +Tests for tortoise.context module - TortoiseContext class. + +These tests verify the context-based state management for Tortoise ORM. 
+ +Note: These tests may run with a session-scoped default context active +(created by Tortoise.init() in conftest.py). The tests account for this +by testing context isolation relative to the current state. +""" + +import pytest + +from tortoise.connection import ConnectionHandler +from tortoise.context import ( + TortoiseContext, + get_current_context, + require_context, + tortoise_test_context, +) +from tortoise.exceptions import ConfigurationError + + +class TestTortoiseContextInstantiation: + """Test cases for TortoiseContext instantiation.""" + + def test_context_instantiation_initial_state(self): + """TortoiseContext instantiation has correct initial state.""" + ctx = TortoiseContext() + + assert ctx._connections is None + assert ctx._apps is None + assert ctx._inited is False + assert ctx.inited is False + + def test_context_connections_property_lazy_creation(self): + """ConnectionHandler is lazily created on first access.""" + ctx = TortoiseContext() + + # Before access + assert ctx._connections is None + + # After access + connections = ctx.connections + assert ctx._connections is not None + assert isinstance(connections, ConnectionHandler) + + def test_context_apps_property_initially_none(self): + """Apps property is initially None.""" + ctx = TortoiseContext() + + assert ctx.apps is None + + +class TestContextManagerProtocol: + """Test cases for context manager protocol.""" + + def test_context_manager_sets_current_context(self): + """Context manager sets current context.""" + # Save original context (may be session-scoped default) + original_ctx = get_current_context() + + with TortoiseContext() as ctx: + assert get_current_context() is ctx + + # After exit, should return to original + assert get_current_context() is original_ctx + + def test_context_manager_resets_on_exit(self): + """Context manager resets on exit.""" + original_ctx = get_current_context() + + with TortoiseContext(): + pass + + assert get_current_context() is original_ctx + + def 
test_nested_contexts_work_correctly(self): + """Nested contexts work correctly.""" + original_ctx = get_current_context() + + with TortoiseContext() as outer: + assert get_current_context() is outer + + with TortoiseContext() as inner: + assert get_current_context() is inner + + # After inner exits, should return to outer + assert get_current_context() is outer + + # After all exit, should return to original + assert get_current_context() is original_ctx + + +class TestRequireContext: + """Test cases for require_context function.""" + + def test_require_context_raises_when_no_context(self): + """require_context raises when no context is active. + + Note: With the new architecture, Tortoise.init() creates a default context, + so this test only passes when run in complete isolation. When a session-scoped + context exists, require_context() returns it instead of raising. + """ + # This behavior depends on whether a session context exists + original_ctx = get_current_context() + if original_ctx is not None: + # Session context exists, require_context should return it + result = require_context() + assert result is original_ctx + else: + # No session context, should raise + with pytest.raises(RuntimeError) as exc_info: + require_context() + assert "No TortoiseContext is currently active" in str(exc_info.value) + + def test_require_context_returns_active_context(self): + """require_context returns the active context when one exists.""" + with TortoiseContext() as ctx: + result = require_context() + assert result is ctx + + +class TestConnectionHandlerIsolation: + """Test cases for ConnectionHandler isolation.""" + + def test_each_context_gets_own_connection_handler(self): + """Each context gets own ConnectionHandler.""" + ctx1 = TortoiseContext() + ctx2 = TortoiseContext() + + # Access connections property on both + conn1 = ctx1.connections + conn2 = ctx2.connections + + # Should be different instances + assert conn1 is not conn2 + + def 
test_context_connections_isolated_from_global(self): + """Context connections isolated from global.""" + ctx = TortoiseContext() + + # Context's ConnectionHandler should be completely independent + # It should not have any config yet + with pytest.raises(ConfigurationError): + ctx.connections.db_config + + +class TestAsyncContextManager: + """Test cases for async context manager protocol.""" + + @pytest.mark.asyncio + async def test_async_context_manager_sets_current_context(self): + """Async context manager sets current context.""" + original_ctx = get_current_context() + + async with TortoiseContext() as ctx: + assert get_current_context() is ctx + + # After exit, should return to original + assert get_current_context() is original_ctx + + @pytest.mark.asyncio + async def test_async_context_manager_resets_on_exit(self): + """Async context manager resets on exit.""" + original_ctx = get_current_context() + + async with TortoiseContext(): + pass + + assert get_current_context() is original_ctx + + @pytest.mark.asyncio + async def test_connections_cleaned_on_async_context_exit(self): + """Connections closed on async context exit.""" + ctx = TortoiseContext() + + # Access connections to create the handler + _ = ctx.connections + assert ctx._connections is not None + + async with ctx: + pass + + # After exit, connections should be cleaned up + assert ctx._connections is None + + @pytest.mark.asyncio + async def test_apps_cleared_on_async_context_exit(self): + """Apps cleared on context exit.""" + ctx = TortoiseContext() + + async with ctx: + # Manually set apps to simulate initialization + ctx._apps = {} + ctx._inited = True + + # After exit, apps should be cleared + assert ctx.apps is None + assert ctx.inited is False + + +class TestGetModel: + """Test cases for get_model method.""" + + def test_get_model_raises_when_not_initialized(self): + """get_model raises when not initialized.""" + ctx = TortoiseContext() + + with pytest.raises(ConfigurationError) as exc_info: 
+ ctx.get_model("models", "User") + + assert "Context not initialized" in str(exc_info.value) + + +class TestInit: + """Test cases for init method.""" + + @pytest.mark.asyncio + async def test_init_raises_without_required_params(self): + """init() raises without required params.""" + ctx = TortoiseContext() + + with pytest.raises(ConfigurationError) as exc_info: + await ctx.init() + + assert "Must provide either 'config', 'config_file', or both 'db_url' and 'modules'" in str( + exc_info.value + ) + + @pytest.mark.asyncio + async def test_init_raises_with_only_db_url(self): + """init() raises with only db_url (no modules).""" + ctx = TortoiseContext() + + with pytest.raises(ConfigurationError) as exc_info: + await ctx.init(db_url="sqlite://:memory:") + + assert "Must provide either 'config', 'config_file', or both 'db_url' and 'modules'" in str( + exc_info.value + ) + + @pytest.mark.asyncio + async def test_init_raises_with_only_modules(self): + """init() raises with only modules (no db_url).""" + ctx = TortoiseContext() + + with pytest.raises(ConfigurationError) as exc_info: + await ctx.init(modules={"models": ["tests.testmodels"]}) + + assert "Must provide either 'config', 'config_file', or both 'db_url' and 'modules'" in str( + exc_info.value + ) + + @pytest.mark.asyncio + async def test_init_raises_with_invalid_config_no_connections(self): + """init() raises when config missing connections section.""" + ctx = TortoiseContext() + + with pytest.raises(ConfigurationError) as exc_info: + await ctx.init(config={"apps": {}}) + + assert 'Config must define "connections" section' in str(exc_info.value) + + @pytest.mark.asyncio + async def test_init_raises_with_invalid_config_no_apps(self): + """init() raises when config missing apps section.""" + ctx = TortoiseContext() + + with pytest.raises(ConfigurationError) as exc_info: + await ctx.init(config={"connections": {}}) + + assert 'Config must define "apps" section' in str(exc_info.value) + + +class TestGenerateSchemas: 
+ """Test cases for generate_schemas method.""" + + @pytest.mark.asyncio + async def test_generate_schemas_raises_when_not_initialized(self): + """generate_schemas() raises when context not initialized.""" + ctx = TortoiseContext() + + with pytest.raises(ConfigurationError) as exc_info: + await ctx.generate_schemas() + + assert "Context not initialized" in str(exc_info.value) + + +class TestTortoiseTestContext: + """Test cases for tortoise_test_context helper.""" + + @pytest.mark.asyncio + async def test_tortoise_test_context_creates_isolated_context(self): + """tortoise_test_context creates isolated context.""" + original_ctx = get_current_context() + + async with tortoise_test_context(["tests.testmodels"]) as ctx: + # Context should be active + assert get_current_context() is ctx + # Context should be initialized + assert ctx.inited is True + # Context should have apps + assert ctx.apps is not None + + # Context should be restored to original (session context if present) + assert get_current_context() is original_ctx + + @pytest.mark.asyncio + async def test_tortoise_test_context_multiple_isolated(self): + """Multiple tortoise_test_context calls are isolated.""" + async with tortoise_test_context(["tests.testmodels"]) as ctx1: + conn1 = ctx1.connections + + async with tortoise_test_context(["tests.testmodels"]) as ctx2: + conn2 = ctx2.connections + + # Different context instances + assert ctx1 is not ctx2 + # Different connection handlers + assert conn1 is not conn2 + + +class TestInitIntegration: + """Integration test cases for init method.""" + + @pytest.mark.asyncio + async def test_init_with_db_url_and_modules(self): + """init() with db_url and modules initializes context correctly.""" + async with TortoiseContext() as ctx: + await ctx.init( + db_url="sqlite://:memory:", + modules={"models": ["tests.testmodels"]}, + ) + + # Context should be initialized + assert ctx.inited is True + + # Connections should be populated + assert ctx._connections is not None + # 
Should be able to get the default connection + conn = ctx.connections.get("default") + assert conn is not None + + # Apps should be populated + assert ctx.apps is not None + # Should be able to get a model + Author = ctx.get_model("models", "Author") + assert Author.__name__ == "Author" + + @pytest.mark.asyncio + async def test_generate_schemas_creates_tables(self): + """generate_schemas() creates tables in the database.""" + async with TortoiseContext() as ctx: + await ctx.init( + db_url="sqlite://:memory:", + modules={"models": ["tests.testmodels"]}, + ) + await ctx.generate_schemas() + + # Verify tables exist by querying sqlite_master + conn = ctx.connections.get("default") + result = await conn.execute_query( + "SELECT name FROM sqlite_master WHERE type='table' AND name='author'" + ) + tables = [row["name"] for row in result[1]] + assert "author" in tables + + @pytest.mark.asyncio + async def test_full_context_lifecycle_with_crud(self): + """Full context lifecycle with model CRUD operations.""" + original_ctx = get_current_context() + + async with TortoiseContext() as ctx: + await ctx.init( + db_url="sqlite://:memory:", + modules={"models": ["tests.testmodels"]}, + ) + await ctx.generate_schemas() + + # Import model inside test to ensure it uses active context + from tests.testmodels import Author + + # CREATE + author = await Author.create(name="Test Author") + assert author.id is not None + assert author.name == "Test Author" + + # READ + fetched = await Author.get(id=author.id) + assert fetched.name == "Test Author" + + # UPDATE + fetched.name = "Updated Author" + await fetched.save() + updated = await Author.get(id=author.id) + assert updated.name == "Updated Author" + + # DELETE + await updated.delete() + count = await Author.filter(id=author.id).count() + assert count == 0 + + # Context should be restored to original (session context if present) + assert get_current_context() is original_ctx + + +class TestModelContextResolution: + """Test cases for model 
context resolution.""" + + @pytest.mark.asyncio + async def test_model_uses_context_connections_when_active(self): + """Model uses context when context is active.""" + async with TortoiseContext() as ctx: + await ctx.init( + db_url="sqlite://:memory:", + modules={"models": ["tests.testmodels"]}, + ) + await ctx.generate_schemas() + + from tests.testmodels import Author + + # Create a record + author = await Author.create(name="Context Author") + assert author.id is not None + + # Verify we can query it + all_authors = await Author.all() + assert len(all_authors) == 1 + assert all_authors[0].name == "Context Author" + + @pytest.mark.asyncio + async def test_sequential_contexts_isolated(self): + """Sequential contexts are isolated from each other.""" + from tests.testmodels import Author + + # First context creates its own author + async with TortoiseContext() as ctx1: + await ctx1.init( + db_url="sqlite://:memory:", + modules={"models": ["tests.testmodels"]}, + ) + await ctx1.generate_schemas() + + await Author.create(name="Author in Context 1") + authors_in_ctx1 = await Author.all() + assert len(authors_in_ctx1) == 1 + assert authors_in_ctx1[0].name == "Author in Context 1" + + # Second context should start with empty database + async with TortoiseContext() as ctx2: + await ctx2.init( + db_url="sqlite://:memory:", + modules={"models": ["tests.testmodels"]}, + ) + await ctx2.generate_schemas() + + # Should be empty - isolated from first context + authors_before = await Author.all() + assert len(authors_before) == 0, "Second context should start empty" + + await Author.create(name="Author in Context 2") + authors_in_ctx2 = await Author.all() + assert len(authors_in_ctx2) == 1 + assert authors_in_ctx2[0].name == "Author in Context 2" + + +class TestTimezoneAndRouters: + """Test cases for timezone and routers configuration.""" + + def test_context_default_timezone_settings(self): + """Context has default timezone settings.""" + ctx = TortoiseContext() + assert 
ctx.use_tz is False + assert ctx.timezone == "UTC" + assert ctx.routers == [] + + @pytest.mark.asyncio + async def test_init_with_timezone_settings(self): + """Context can be initialized with timezone settings.""" + async with TortoiseContext() as ctx: + await ctx.init( + db_url="sqlite://:memory:", + modules={"models": ["tests.testmodels"]}, + use_tz=True, + timezone="America/New_York", + ) + + assert ctx.use_tz is True + assert ctx.timezone == "America/New_York" + + @pytest.mark.asyncio + async def test_init_with_config_dict_timezone(self): + """Timezone settings from config dict are used.""" + async with TortoiseContext() as ctx: + await ctx.init( + config={ + "connections": {"default": "sqlite://:memory:"}, + "apps": {"models": {"models": ["tests.testmodels"]}}, + "use_tz": True, + "timezone": "Europe/London", + } + ) + + assert ctx.use_tz is True + assert ctx.timezone == "Europe/London" + + @pytest.mark.asyncio + async def test_tortoise_test_context_with_timezone(self): + """tortoise_test_context supports timezone parameters.""" + async with tortoise_test_context( + ["tests.testmodels"], + use_tz=True, + timezone="Asia/Tokyo", + ) as ctx: + assert ctx.use_tz is True + assert ctx.timezone == "Asia/Tokyo" + + +class TestTortoiseConfigValidation: + """Test cases for TortoiseConfig validation in ctx.init().""" + + @pytest.mark.asyncio + async def test_init_accepts_tortoise_config_object(self): + """ctx.init() accepts TortoiseConfig object directly.""" + from tortoise.config import AppConfig, DBUrlConfig, TortoiseConfig + + config = TortoiseConfig( + connections={"default": DBUrlConfig("sqlite://:memory:")}, + apps={"models": AppConfig(models=["tests.testmodels"])}, + use_tz=True, + timezone="UTC", + ) + + async with TortoiseContext() as ctx: + await ctx.init(config=config) + assert ctx.inited is True + assert ctx.use_tz is True + + @pytest.mark.asyncio + async def test_init_validates_dict_config(self): + """ctx.init() validates dict config and raises 
ConfigurationError on issues.""" + async with TortoiseContext() as ctx: + # Config with missing 'models' in app should raise ConfigurationError + # because TortoiseConfig.from_dict validates the structure + with pytest.raises(ConfigurationError) as exc_info: + await ctx.init( + config={ + "connections": {"default": "sqlite://:memory:"}, + "apps": {"models": {}}, # Missing 'models' key + } + ) + assert "models" in str(exc_info.value).lower() diff --git a/tests/test_default.py b/tests/test_default.py index a361778a8..fb3a831d1 100644 --- a/tests/test_default.py +++ b/tests/test_default.py @@ -1,49 +1,64 @@ import datetime from decimal import Decimal +import pytest +import pytest_asyncio import pytz from tests.testmodels import DefaultModel from tortoise import connections from tortoise.backends.asyncpg import AsyncpgDBClient -from tortoise.backends.mssql import MSSQLClient from tortoise.backends.mysql import MySQLClient -from tortoise.backends.oracle import OracleClient from tortoise.backends.psycopg import PsycopgClient from tortoise.backends.sqlite import SqliteClient -from tortoise.contrib import test - - -class TestDefault(test.TestCase): - async def asyncSetUp(self) -> None: - await super().asyncSetUp() - db = connections.get("models") - if isinstance(db, MySQLClient): - await db.execute_query( - "insert into defaultmodel (`int_default`,`float_default`,`decimal_default`,`bool_default`,`char_default`,`date_default`,`datetime_default`) values (DEFAULT,DEFAULT,DEFAULT,DEFAULT,DEFAULT,DEFAULT,DEFAULT)", - ) - elif isinstance(db, SqliteClient): - await db.execute_query( - "insert into defaultmodel default values", - ) - elif isinstance(db, (AsyncpgDBClient, PsycopgClient, MSSQLClient)): - await db.execute_query( - 'insert into defaultmodel ("int_default","float_default","decimal_default","bool_default","char_default","date_default","datetime_default") values (DEFAULT,DEFAULT,DEFAULT,DEFAULT,DEFAULT,DEFAULT,DEFAULT)', - ) - elif isinstance(db, OracleClient): - await 
db.execute_query( - 'insert into "defaultmodel" ("int_default","float_default","decimal_default","bool_default","char_default","date_default","datetime_default") values (DEFAULT,DEFAULT,DEFAULT,DEFAULT,DEFAULT,DEFAULT,DEFAULT)', - ) - - async def test_default(self): - default_model = await DefaultModel.first() - self.assertEqual(default_model.int_default, 1) - self.assertEqual(default_model.float_default, 1.5) - self.assertEqual(default_model.decimal_default, Decimal(1)) - self.assertTrue(default_model.bool_default) - self.assertEqual(default_model.char_default, "tortoise") - self.assertEqual(default_model.date_default, datetime.date(year=2020, month=5, day=21)) - self.assertEqual( - default_model.datetime_default, - datetime.datetime(year=2020, month=5, day=20, tzinfo=pytz.utc), + +# Optional imports for database clients that require system dependencies +try: + from tortoise.backends.mssql import MSSQLClient +except ImportError: + MSSQLClient = None # type: ignore[misc,assignment] + +try: + from tortoise.backends.oracle import OracleClient +except ImportError: + OracleClient = None # type: ignore[misc,assignment] + + +@pytest_asyncio.fixture +async def default_row(db): + """Insert a default row using raw SQL based on database type.""" + db_conn = connections.get("models") + if isinstance(db_conn, MySQLClient): + await db_conn.execute_query( + "insert into defaultmodel (`int_default`,`float_default`,`decimal_default`,`bool_default`,`char_default`,`date_default`,`datetime_default`) values (DEFAULT,DEFAULT,DEFAULT,DEFAULT,DEFAULT,DEFAULT,DEFAULT)", + ) + elif isinstance(db_conn, SqliteClient): + await db_conn.execute_query( + "insert into defaultmodel default values", ) + elif isinstance(db_conn, (AsyncpgDBClient, PsycopgClient)) or ( + MSSQLClient is not None and isinstance(db_conn, MSSQLClient) + ): + await db_conn.execute_query( + 'insert into defaultmodel 
("int_default","float_default","decimal_default","bool_default","char_default","date_default","datetime_default") values (DEFAULT,DEFAULT,DEFAULT,DEFAULT,DEFAULT,DEFAULT,DEFAULT)', + ) + elif OracleClient is not None and isinstance(db_conn, OracleClient): + await db_conn.execute_query( + 'insert into "defaultmodel" ("int_default","float_default","decimal_default","bool_default","char_default","date_default","datetime_default") values (DEFAULT,DEFAULT,DEFAULT,DEFAULT,DEFAULT,DEFAULT,DEFAULT)', + ) + yield + + +@pytest.mark.asyncio +async def test_default(default_row): + """Test that default values are correctly applied when inserting via raw SQL.""" + default_model = await DefaultModel.first() + assert default_model.int_default == 1 + assert default_model.float_default == 1.5 + assert default_model.decimal_default == Decimal(1) + assert default_model.bool_default + assert default_model.char_default == "tortoise" + assert default_model.date_default == datetime.date(year=2020, month=5, day=21) + assert default_model.datetime_default == datetime.datetime( + year=2020, month=5, day=20, tzinfo=pytz.utc + ) diff --git a/tests/test_early_init.py b/tests/test_early_init.py index 87dbd4015..22195ec1c 100644 --- a/tests/test_early_init.py +++ b/tests/test_early_init.py @@ -1,5 +1,6 @@ +import pytest + from tortoise import Tortoise, fields -from tortoise.contrib import test from tortoise.contrib.pydantic import pydantic_model_creator from tortoise.models import Model @@ -35,165 +36,134 @@ class Meta: ordering = ["name"] -class TestBasic(test.SimpleTestCase): - async def test_early_init(self): - self.maxDiff = None - Event_TooEarly = pydantic_model_creator(Event) - self.assertEqual( - Event_TooEarly.model_json_schema(), +@pytest.mark.asyncio +async def test_early_init(): + Event_TooEarly = pydantic_model_creator(Event) + assert Event_TooEarly.model_json_schema() == { + "title": "Event", + "type": "object", + "description": "The Event model docstring.

This is multiline docs.", + "properties": { + "id": { + "title": "Id", + "type": "integer", + "maximum": 2147483647, + "minimum": -2147483648, + }, + "name": { + "title": "Name", + "type": "string", + "description": "The Event NAME
It's pretty important", + "maxLength": 255, + }, + "created_at": { + "title": "Created At", + "type": "string", + "format": "date-time", + "readOnly": True, + }, + }, + "required": ["id", "name", "created_at"], + "additionalProperties": False, + } + assert Event.describe() == { + "name": "None.Event", + "app": None, + "table": "", + "abstract": False, + "description": "The Event model docstring.", + "docstring": "The Event model docstring.\n\nThis is multiline docs.", + "unique_together": [], + "indexes": [], + "pk_field": { + "name": "id", + "field_type": "IntField", + "db_column": "id", + "python_type": "int", + "generated": True, + "nullable": False, + "unique": True, + "indexed": True, + "default": None, + "description": None, + "docstring": None, + "constraints": {"ge": -2147483648, "le": 2147483647}, + "db_field_types": {"": "INT"}, + }, + "data_fields": [ { - "title": "Event", - "type": "object", - "description": "The Event model docstring.

This is multiline docs.", - "properties": { - "id": { - "title": "Id", - "type": "integer", - "maximum": 2147483647, - "minimum": -2147483648, - }, - "name": { - "title": "Name", - "type": "string", - "description": "The Event NAME
It's pretty important", - "maxLength": 255, - }, - "created_at": { - "title": "Created At", - "type": "string", - "format": "date-time", - "readOnly": True, - }, + "name": "name", + "field_type": "CharField", + "db_column": "name", + "python_type": "str", + "generated": False, + "nullable": False, + "unique": False, + "indexed": False, + "default": None, + "description": "The Event NAME", + "docstring": "The Event NAME\nIt's pretty important", + "constraints": {"max_length": 255}, + "db_field_types": { + "": "VARCHAR(255)", + "oracle": "NVARCHAR2(255)", }, - "required": ["id", "name", "created_at"], - "additionalProperties": False, }, - ) - self.assertEqual( - Event.describe(), { - "name": "None.Event", - "app": None, - "table": "", - "abstract": False, - "description": "The Event model docstring.", - "docstring": "The Event model docstring.\n\nThis is multiline docs.", - "unique_together": [], - "indexes": [], - "pk_field": { - "name": "id", - "field_type": "IntField", - "db_column": "id", - "python_type": "int", - "generated": True, - "nullable": False, - "unique": True, - "indexed": True, - "default": None, - "description": None, - "docstring": None, - "constraints": {"ge": -2147483648, "le": 2147483647}, - "db_field_types": {"": "INT"}, + "name": "created_at", + "field_type": "DatetimeField", + "db_column": "created_at", + "python_type": "datetime.datetime", + "generated": False, + "nullable": False, + "unique": False, + "indexed": False, + "default": None, + "description": None, + "docstring": None, + "constraints": {"readOnly": True}, + "db_field_types": { + "": "TIMESTAMP", + "mssql": "DATETIME2", + "mysql": "DATETIME(6)", + "postgres": "TIMESTAMPTZ", + "oracle": "TIMESTAMP WITH TIME ZONE", }, - "data_fields": [ - { - "name": "name", - "field_type": "CharField", - "db_column": "name", - "python_type": "str", - "generated": False, - "nullable": False, - "unique": False, - "indexed": False, - "default": None, - "description": "The Event NAME", - "docstring": 
"The Event NAME\nIt's pretty important", - "constraints": {"max_length": 255}, - "db_field_types": { - "": "VARCHAR(255)", - "oracle": "NVARCHAR2(255)", - }, - }, - { - "name": "created_at", - "field_type": "DatetimeField", - "db_column": "created_at", - "python_type": "datetime.datetime", - "generated": False, - "nullable": False, - "unique": False, - "indexed": False, - "default": None, - "description": None, - "docstring": None, - "constraints": {"readOnly": True}, - "db_field_types": { - "": "TIMESTAMP", - "mssql": "DATETIME2", - "mysql": "DATETIME(6)", - "postgres": "TIMESTAMPTZ", - "oracle": "TIMESTAMP WITH TIME ZONE", - }, - "auto_now_add": True, - "auto_now": False, - }, - ], - "fk_fields": [ - { - "name": "tournament", - "field_type": "ForeignKeyFieldInstance", - "python_type": "None", - "generated": False, - "nullable": True, - "unique": False, - "indexed": False, - "default": None, - "description": None, - "docstring": None, - "constraints": {}, - "raw_field": None, - "on_delete": "CASCADE", - "db_constraint": True, - } - ], - "backward_fk_fields": [], - "o2o_fields": [], - "backward_o2o_fields": [], - "m2m_fields": [], + "auto_now_add": True, + "auto_now": False, }, - ) + ], + "fk_fields": [ + { + "name": "tournament", + "field_type": "ForeignKeyFieldInstance", + "python_type": "None", + "generated": False, + "nullable": True, + "unique": False, + "indexed": False, + "default": None, + "description": None, + "docstring": None, + "constraints": {}, + "raw_field": None, + "on_delete": "CASCADE", + "db_constraint": True, + } + ], + "backward_fk_fields": [], + "o2o_fields": [], + "backward_o2o_fields": [], + "m2m_fields": [], + } - Tortoise.init_models(["tests.test_early_init"], "models") + Tortoise.init_models(["tests.test_early_init"], "models") - Event_Pydantic = pydantic_model_creator(Event) - self.assertEqual( - Event_Pydantic.model_json_schema(), - { - "$defs": { - "Tournament_aapnxb_leaf": { - "additionalProperties": False, - "properties": { - "id": 
{ - "maximum": 2147483647, - "minimum": -2147483648, - "title": "Id", - "type": "integer", - }, - "name": {"maxLength": 100, "title": "Name", "type": "string"}, - "created_at": { - "format": "date-time", - "readOnly": True, - "title": "Created At", - "type": "string", - }, - }, - "required": ["id", "name", "created_at"], - "title": "Tournament", - "type": "object", - } - }, + Event_Pydantic = pydantic_model_creator(Event) + assert Event_Pydantic.model_json_schema() == { + "$defs": { + "Tournament_aapnxb_leaf": { "additionalProperties": False, - "description": "The Event model docstring.

This is multiline docs.", "properties": { "id": { "maximum": 2147483647, @@ -201,134 +171,152 @@ async def test_early_init(self): "title": "Id", "type": "integer", }, - "name": { - "description": "The Event NAME
It's pretty important", - "maxLength": 255, - "title": "Name", - "type": "string", - }, + "name": {"maxLength": 100, "title": "Name", "type": "string"}, "created_at": { "format": "date-time", "readOnly": True, "title": "Created At", "type": "string", }, - "tournament": { - "anyOf": [{"$ref": "#/$defs/Tournament_aapnxb_leaf"}, {"type": "null"}], - "nullable": True, - "title": "Tournament", - }, }, - "required": ["id", "name", "created_at", "tournament"], - "title": "Event", + "required": ["id", "name", "created_at"], + "title": "Tournament", "type": "object", + } + }, + "additionalProperties": False, + "description": "The Event model docstring.

This is multiline docs.", + "properties": { + "id": { + "maximum": 2147483647, + "minimum": -2147483648, + "title": "Id", + "type": "integer", + }, + "name": { + "description": "The Event NAME
It's pretty important", + "maxLength": 255, + "title": "Name", + "type": "string", + }, + "created_at": { + "format": "date-time", + "readOnly": True, + "title": "Created At", + "type": "string", }, - ) - self.assertEqual( - Event.describe(), + "tournament": { + "anyOf": [{"$ref": "#/$defs/Tournament_aapnxb_leaf"}, {"type": "null"}], + "nullable": True, + "title": "Tournament", + }, + }, + "required": ["id", "name", "created_at", "tournament"], + "title": "Event", + "type": "object", + } + assert Event.describe() == { + "name": "models.Event", + "app": "models", + "table": "event", + "abstract": False, + "description": "The Event model docstring.", + "docstring": "The Event model docstring.\n\nThis is multiline docs.", + "unique_together": [], + "indexes": [], + "pk_field": { + "name": "id", + "field_type": "IntField", + "db_column": "id", + "db_field_types": {"": "INT"}, + "python_type": "int", + "generated": True, + "nullable": False, + "unique": True, + "indexed": True, + "default": None, + "description": None, + "docstring": None, + "constraints": {"ge": -2147483648, "le": 2147483647}, + }, + "data_fields": [ { - "name": "models.Event", - "app": "models", - "table": "event", - "abstract": False, - "description": "The Event model docstring.", - "docstring": "The Event model docstring.\n\nThis is multiline docs.", - "unique_together": [], - "indexes": [], - "pk_field": { - "name": "id", - "field_type": "IntField", - "db_column": "id", - "db_field_types": {"": "INT"}, - "python_type": "int", - "generated": True, - "nullable": False, - "unique": True, - "indexed": True, - "default": None, - "description": None, - "docstring": None, - "constraints": {"ge": -2147483648, "le": 2147483647}, + "name": "name", + "field_type": "CharField", + "db_column": "name", + "db_field_types": { + "": "VARCHAR(255)", + "oracle": "NVARCHAR2(255)", }, - "data_fields": [ - { - "name": "name", - "field_type": "CharField", - "db_column": "name", - "db_field_types": { - "": "VARCHAR(255)", 
- "oracle": "NVARCHAR2(255)", - }, - "python_type": "str", - "generated": False, - "nullable": False, - "unique": False, - "indexed": False, - "default": None, - "description": "The Event NAME", - "docstring": "The Event NAME\nIt's pretty important", - "constraints": {"max_length": 255}, - }, - { - "name": "created_at", - "field_type": "DatetimeField", - "db_column": "created_at", - "db_field_types": { - "": "TIMESTAMP", - "mssql": "DATETIME2", - "mysql": "DATETIME(6)", - "postgres": "TIMESTAMPTZ", - "oracle": "TIMESTAMP WITH TIME ZONE", - }, - "python_type": "datetime.datetime", - "generated": False, - "nullable": False, - "unique": False, - "indexed": False, - "default": None, - "description": None, - "docstring": None, - "constraints": {"readOnly": True}, - "auto_now_add": True, - "auto_now": False, - }, - { - "name": "tournament_id", - "field_type": "IntField", - "db_column": "tournament_id", - "db_field_types": {"": "INT"}, - "python_type": "int", - "generated": False, - "nullable": True, - "unique": False, - "indexed": False, - "default": None, - "description": None, - "docstring": None, - "constraints": {"ge": -2147483648, "le": 2147483647}, - }, - ], - "fk_fields": [ - { - "name": "tournament", - "field_type": "ForeignKeyFieldInstance", - "raw_field": "tournament_id", - "python_type": "models.Tournament", - "generated": False, - "nullable": True, - "on_delete": "CASCADE", - "unique": False, - "indexed": False, - "default": None, - "description": None, - "docstring": None, - "constraints": {}, - "db_constraint": True, - } - ], - "backward_fk_fields": [], - "o2o_fields": [], - "backward_o2o_fields": [], - "m2m_fields": [], + "python_type": "str", + "generated": False, + "nullable": False, + "unique": False, + "indexed": False, + "default": None, + "description": "The Event NAME", + "docstring": "The Event NAME\nIt's pretty important", + "constraints": {"max_length": 255}, + }, + { + "name": "created_at", + "field_type": "DatetimeField", + "db_column": 
"created_at", + "db_field_types": { + "": "TIMESTAMP", + "mssql": "DATETIME2", + "mysql": "DATETIME(6)", + "postgres": "TIMESTAMPTZ", + "oracle": "TIMESTAMP WITH TIME ZONE", + }, + "python_type": "datetime.datetime", + "generated": False, + "nullable": False, + "unique": False, + "indexed": False, + "default": None, + "description": None, + "docstring": None, + "constraints": {"readOnly": True}, + "auto_now_add": True, + "auto_now": False, }, - ) + { + "name": "tournament_id", + "field_type": "IntField", + "db_column": "tournament_id", + "db_field_types": {"": "INT"}, + "python_type": "int", + "generated": False, + "nullable": True, + "unique": False, + "indexed": False, + "default": None, + "description": None, + "docstring": None, + "constraints": {"ge": -2147483648, "le": 2147483647}, + }, + ], + "fk_fields": [ + { + "name": "tournament", + "field_type": "ForeignKeyFieldInstance", + "raw_field": "tournament_id", + "python_type": "models.Tournament", + "generated": False, + "nullable": True, + "on_delete": "CASCADE", + "unique": False, + "indexed": False, + "default": None, + "description": None, + "docstring": None, + "constraints": {}, + "db_constraint": True, + } + ], + "backward_fk_fields": [], + "o2o_fields": [], + "backward_o2o_fields": [], + "m2m_fields": [], + } diff --git a/tests/test_f.py b/tests/test_f.py index 32f0f4611..487479a6a 100644 --- a/tests/test_f.py +++ b/tests/test_f.py @@ -1,144 +1,163 @@ +import pytest + from tests.testmodels import JSONFields -from tortoise.contrib import test +from tortoise.contrib.test import requireCapability from tortoise.expressions import Connector, F -class TestF(test.TestCase): - def test_arithmetic(self): - f = F("name") - - negated = -f - self.assertEqual(negated.connector, Connector.mul) - self.assertEqual(negated.right.value, -1) - - added = f + 1 - self.assertEqual(added.connector, Connector.add) - self.assertEqual(added.right.value, 1) - - radded = 1 + f - self.assertEqual(radded.connector, Connector.add) - 
self.assertEqual(radded.left.value, 1) - self.assertEqual(radded.right, f) - - subbed = f - 1 - self.assertEqual(subbed.connector, Connector.sub) - self.assertEqual(subbed.right.value, 1) - - rsubbed = 1 - f - self.assertEqual(rsubbed.connector, Connector.sub) - self.assertEqual(rsubbed.left.value, 1) - - mulled = f * 2 - self.assertEqual(mulled.connector, Connector.mul) - self.assertEqual(mulled.right.value, 2) - - rmulled = 2 * f - self.assertEqual(rmulled.connector, Connector.mul) - self.assertEqual(rmulled.left.value, 2) - - divved = f / 2 - self.assertEqual(divved.connector, Connector.div) - self.assertEqual(divved.right.value, 2) - - rdivved = 2 / f - self.assertEqual(rdivved.connector, Connector.div) - self.assertEqual(rdivved.left.value, 2) - - powed = f**2 - self.assertEqual(powed.connector, Connector.pow) - self.assertEqual(powed.right.value, 2) - - rpowed = 2**f - self.assertEqual(rpowed.connector, Connector.pow) - self.assertEqual(rpowed.left.value, 2) - - modded = f % 2 - self.assertEqual(modded.connector, Connector.mod) - self.assertEqual(modded.right.value, 2) - - rmodded = 2 % f - self.assertEqual(rmodded.connector, Connector.mod) - self.assertEqual(rmodded.left.value, 2) - - @test.requireCapability(support_json_attributes=True) - async def test_values_with_json_field_attribute(self): - await JSONFields.create(data='{"attribute": 1}') - res = await JSONFields.annotate(attribute=F("data__attribute")).first() - self.assertEqual(int(res.attribute), 1) - - @test.requireCapability(support_json_attributes=True) - async def test_values_with_json_field_attribute_of_attribute(self): - await JSONFields.create(data='{"attribute": {"subattribute": "value"}}') - res = await JSONFields.annotate(subattribute=F("data__attribute__subattribute")).first() - self.assertEqual(res.subattribute, "value") - - @test.requireCapability(support_json_attributes=True) - async def test_values_with_json_field_str_array_element(self): - await JSONFields.create(data='["a", "b", 
"c"]') - res = await JSONFields.annotate(array_element=F("data__0")).first() - self.assertEqual(res.array_element, "a") - res = await JSONFields.annotate(array_element=F("data__1")).first() - self.assertEqual(res.array_element, "b") - res = await JSONFields.annotate(array_element=F("data__2")).first() - self.assertEqual(res.array_element, "c") - res = await JSONFields.annotate(array_element=F("data__3")).first() - self.assertIsNone(res.array_element) - - @test.requireCapability(support_json_attributes=True) - async def test_values_with_json_field_array_attribute(self): - await JSONFields.create(data='{"array": ["a", "b", "c"]}') - res = await JSONFields.annotate(array_attribute=F("data__array__0")).first() - self.assertEqual(res.array_attribute, "a") - res = await JSONFields.annotate(array_attribute=F("data__array__1")).first() - self.assertEqual(res.array_attribute, "b") - res = await JSONFields.annotate(array_attribute=F("data__array__2")).first() - self.assertEqual(res.array_attribute, "c") - - @test.requireCapability(support_json_attributes=True) - async def test_values_with_json_field_int_array_element(self): - """ - Among the supported dialects, only SQLite will return the correct type. 
- """ - await JSONFields.create(data="[1, 2, 3]") - res = await JSONFields.annotate(array_element=F("data__0")).first() - self.assertEqual(int(res.array_element), 1) - res = await JSONFields.annotate(array_element=F("data__1")).first() - self.assertEqual(int(res.array_element), 2) - res = await JSONFields.annotate(array_element=F("data__2")).first() - self.assertEqual(int(res.array_element), 3) - res = await JSONFields.annotate(array_element=F("data__3")).first() - self.assertIsNone(res.array_element) - - @test.requireCapability(support_json_attributes=True) - async def test_filter_with_json_field_attribute(self): - exp = await JSONFields.create(data='{"attribute": "a"}') - res = ( - await JSONFields.annotate(attribute=F("data__attribute")).filter(attribute="a").first() - ) - self.assertEqual(res.id, exp.id) - res = ( - await JSONFields.annotate(attribute=F("data__attribute")).filter(attribute="b").first() - ) - self.assertIsNone(res) - - @test.requireCapability(support_json_attributes=True) - async def test_filter_with_json_field_attribute_of_attribute(self): - exp = await JSONFields.create(data='{"attribute": {"subattribute": "value"}}') - res = ( - await JSONFields.annotate(subattribute=F("data__attribute__subattribute")) - .filter(subattribute="value") - .first() - ) - self.assertEqual(res.id, exp.id) - - @test.requireCapability(support_json_attributes=True) - async def test_filter_with_json_field_str_array_element(self): - exp = await JSONFields.create(data='["a", "b", "c"]') - res = ( - await JSONFields.annotate(array_element=F("data__0")).filter(array_element="a").first() - ) - self.assertEqual(res.id, exp.id) - res = ( - await JSONFields.annotate(array_element=F("data__1")).filter(array_element="b").first() - ) - self.assertEqual(res.id, exp.id) +def test_arithmetic(): + """Test F expression arithmetic operations.""" + f = F("name") + + negated = -f + assert negated.connector == Connector.mul + assert negated.right.value == -1 + + added = f + 1 + assert 
added.connector == Connector.add + assert added.right.value == 1 + + radded = 1 + f + assert radded.connector == Connector.add + assert radded.left.value == 1 + assert radded.right == f + + subbed = f - 1 + assert subbed.connector == Connector.sub + assert subbed.right.value == 1 + + rsubbed = 1 - f + assert rsubbed.connector == Connector.sub + assert rsubbed.left.value == 1 + + mulled = f * 2 + assert mulled.connector == Connector.mul + assert mulled.right.value == 2 + + rmulled = 2 * f + assert rmulled.connector == Connector.mul + assert rmulled.left.value == 2 + + divved = f / 2 + assert divved.connector == Connector.div + assert divved.right.value == 2 + + rdivved = 2 / f + assert rdivved.connector == Connector.div + assert rdivved.left.value == 2 + + powed = f**2 + assert powed.connector == Connector.pow + assert powed.right.value == 2 + + rpowed = 2**f + assert rpowed.connector == Connector.pow + assert rpowed.left.value == 2 + + modded = f % 2 + assert modded.connector == Connector.mod + assert modded.right.value == 2 + + rmodded = 2 % f + assert rmodded.connector == Connector.mod + assert rmodded.left.value == 2 + + +@requireCapability(support_json_attributes=True) +@pytest.mark.asyncio +async def test_values_with_json_field_attribute(db): + """Test F expression with JSON field attribute.""" + await JSONFields.create(data='{"attribute": 1}') + res = await JSONFields.annotate(attribute=F("data__attribute")).first() + assert int(res.attribute) == 1 + + +@requireCapability(support_json_attributes=True) +@pytest.mark.asyncio +async def test_values_with_json_field_attribute_of_attribute(db): + """Test F expression with nested JSON field attribute.""" + await JSONFields.create(data='{"attribute": {"subattribute": "value"}}') + res = await JSONFields.annotate(subattribute=F("data__attribute__subattribute")).first() + assert res.subattribute == "value" + + +@requireCapability(support_json_attributes=True) +@pytest.mark.asyncio +async def 
test_values_with_json_field_str_array_element(db): + """Test F expression with JSON field string array element.""" + await JSONFields.create(data='["a", "b", "c"]') + res = await JSONFields.annotate(array_element=F("data__0")).first() + assert res.array_element == "a" + res = await JSONFields.annotate(array_element=F("data__1")).first() + assert res.array_element == "b" + res = await JSONFields.annotate(array_element=F("data__2")).first() + assert res.array_element == "c" + res = await JSONFields.annotate(array_element=F("data__3")).first() + assert res.array_element is None + + +@requireCapability(support_json_attributes=True) +@pytest.mark.asyncio +async def test_values_with_json_field_array_attribute(db): + """Test F expression with JSON field array attribute.""" + await JSONFields.create(data='{"array": ["a", "b", "c"]}') + res = await JSONFields.annotate(array_attribute=F("data__array__0")).first() + assert res.array_attribute == "a" + res = await JSONFields.annotate(array_attribute=F("data__array__1")).first() + assert res.array_attribute == "b" + res = await JSONFields.annotate(array_attribute=F("data__array__2")).first() + assert res.array_attribute == "c" + + +@requireCapability(support_json_attributes=True) +@pytest.mark.asyncio +async def test_values_with_json_field_int_array_element(db): + """ + Test F expression with JSON field integer array element. + + Among the supported dialects, only SQLite will return the correct type. 
+ """ + await JSONFields.create(data="[1, 2, 3]") + res = await JSONFields.annotate(array_element=F("data__0")).first() + assert int(res.array_element) == 1 + res = await JSONFields.annotate(array_element=F("data__1")).first() + assert int(res.array_element) == 2 + res = await JSONFields.annotate(array_element=F("data__2")).first() + assert int(res.array_element) == 3 + res = await JSONFields.annotate(array_element=F("data__3")).first() + assert res.array_element is None + + +@requireCapability(support_json_attributes=True) +@pytest.mark.asyncio +async def test_filter_with_json_field_attribute(db): + """Test F expression filter with JSON field attribute.""" + exp = await JSONFields.create(data='{"attribute": "a"}') + res = await JSONFields.annotate(attribute=F("data__attribute")).filter(attribute="a").first() + assert res.id == exp.id + res = await JSONFields.annotate(attribute=F("data__attribute")).filter(attribute="b").first() + assert res is None + + +@requireCapability(support_json_attributes=True) +@pytest.mark.asyncio +async def test_filter_with_json_field_attribute_of_attribute(db): + """Test F expression filter with nested JSON field attribute.""" + exp = await JSONFields.create(data='{"attribute": {"subattribute": "value"}}') + res = ( + await JSONFields.annotate(subattribute=F("data__attribute__subattribute")) + .filter(subattribute="value") + .first() + ) + assert res.id == exp.id + + +@requireCapability(support_json_attributes=True) +@pytest.mark.asyncio +async def test_filter_with_json_field_str_array_element(db): + """Test F expression filter with JSON field string array element.""" + exp = await JSONFields.create(data='["a", "b", "c"]') + res = await JSONFields.annotate(array_element=F("data__0")).filter(array_element="a").first() + assert res.id == exp.id + res = await JSONFields.annotate(array_element=F("data__1")).filter(array_element="b").first() + assert res.id == exp.id diff --git a/tests/test_filtering.py b/tests/test_filtering.py index 
dde03376b..bc981b9c4 100644 --- a/tests/test_filtering.py +++ b/tests/test_filtering.py @@ -1,5 +1,7 @@ import datetime +import pytest + from tests.testmodels import ( DateFields, DatetimeFields, @@ -15,514 +17,582 @@ from tortoise.functions import Coalesce, Count, Length, Lower, Max, Trim, Upper -class TestFiltering(test.TestCase): - async def test_filtering(self): - tournament = Tournament(name="Tournament") - await tournament.save() - - second_tournament = Tournament(name="Tournament 2") - await second_tournament.save() - - event_first = Event(name="1", tournament=tournament) - await event_first.save() - event_second = Event(name="2", tournament=second_tournament) - await event_second.save() - event_third = Event(name="3", tournament=tournament) - await event_third.save() - event_forth = Event(name="4", tournament=second_tournament) - await event_forth.save() - - team_first = Team(name="First") - await team_first.save() - team_second = Team(name="Second") - await team_second.save() - - await team_first.events.add(event_first) - await event_second.participants.add(team_second) - - found_events = ( - await Event.filter(Q(pk__in=[event_first.pk, event_second.pk]) | Q(name="3")) - .filter(participants__not=team_second.id) - .order_by("name", "tournament_id") - .distinct() - ) - self.assertEqual(len(found_events), 2) - self.assertEqual(found_events[0].pk, event_first.pk) - self.assertEqual(found_events[1].pk, event_third.pk) - await Team.filter(events__tournament_id=tournament.id).order_by("-events__name") - await Tournament.filter(events__name__in=["1", "3"]).distinct() - - teams = await Team.filter(name__icontains="CON") - self.assertEqual(len(teams), 1) - self.assertEqual(teams[0].name, "Second") - - teams = await Team.filter(name__iexact="SeCoNd") - self.assertEqual(len(teams), 1) - self.assertEqual(teams[0].name, "Second") - - tournaments = await Tournament.filter(events__participants__name__startswith="Fir") - self.assertEqual(len(tournaments), 1) - 
self.assertEqual(tournaments[0], tournament) - - async def test_q_object_backward_related_query(self): - await Tournament.create(name="0") - tournament = await Tournament.create(name="Tournament") - event = await Event.create(name="1", tournament=tournament) - fetched_tournament = await Tournament.filter(events=event.event_id).first() - self.assertEqual(fetched_tournament.id, tournament.id) - - fetched_tournament = await Tournament.filter(Q(events=event.event_id)).first() - self.assertEqual(fetched_tournament.id, tournament.id) - - async def test_q_object_related_query(self): - tournament_first = await Tournament.create(name="0") - tournament_second = await Tournament.create(name="1") - event = await Event.create(name="1", tournament=tournament_second) - await Event.create(name="1", tournament=tournament_first) - - fetched_event = await Event.filter(tournament=tournament_second).first() - self.assertEqual(fetched_event.pk, event.pk) - - fetched_event = await Event.filter(Q(tournament=tournament_second)).first() - self.assertEqual(fetched_event.pk, event.pk) - - fetched_event = await Event.filter(Q(tournament=tournament_second.id)).first() - self.assertEqual(fetched_event.pk, event.pk) - - async def test_null_filter(self): - tournament = await Tournament.create(name="Tournament") - reporter = await Reporter.create(name="John") - await Event.create(name="2", tournament=tournament, reporter=reporter) - event = await Event.create(name="1", tournament=tournament) - fetched_events = await Event.filter(reporter=None) - self.assertEqual(len(fetched_events), 1) - self.assertEqual(fetched_events[0].pk, event.pk) - - async def test_exclude(self): - await Tournament.create(name="0") - tournament = await Tournament.create(name="1") - - tournaments = await Tournament.exclude(name="0") - self.assertEqual(len(tournaments), 1) - self.assertEqual(tournaments[0].name, tournament.name) - - async def test_exclude_with_filter(self): - await Tournament.create(name="0") - tournament = 
await Tournament.create(name="1") - await Tournament.create(name="2") - - tournaments = await Tournament.exclude(name="0").filter(id=tournament.id) - self.assertEqual(len(tournaments), 1) - self.assertEqual(tournaments[0].name, tournament.name) - - async def test_filter_null_on_related(self): - tournament = await Tournament.create(name="Tournament") - reporter = await Reporter.create(name="John") - event_first = await Event.create(name="1", tournament=tournament, reporter=reporter) - event_second = await Event.create(name="2", tournament=tournament) - - team_first = await Team.create(name="1") - team_second = await Team.create(name="2") - await event_first.participants.add(team_first) - await event_second.participants.add(team_second) - - fetched_teams = await Team.filter(events__reporter=None) - self.assertEqual(len(fetched_teams), 1) - self.assertEqual(fetched_teams[0].id, team_second.id) - - async def test_filter_or(self): - await Tournament.create(name="0") - await Tournament.create(name="1") - await Tournament.create(name="2") - - tournaments = await Tournament.filter(Q(name="1") | Q(name="2")) - self.assertEqual(len(tournaments), 2) - self.assertSetEqual({t.name for t in tournaments}, {"1", "2"}) - - async def test_filter_not(self): - await Tournament.create(name="0") - await Tournament.create(name="1") - - tournaments = await Tournament.filter(~Q(name="1")) - self.assertEqual(len(tournaments), 1) - self.assertEqual(tournaments[0].name, "0") - - async def test_filter_with_f_expression(self): - await IntFields.create(intnum=1, intnum_null=1) - await IntFields.create(intnum=2, intnum_null=1) - self.assertEqual(await IntFields.filter(intnum=F("intnum_null")).count(), 1) - self.assertEqual(await IntFields.filter(intnum__gte=F("intnum_null")).count(), 2) - self.assertEqual( - await IntFields.filter(intnum=F("intnum_null") + F("intnum_null")).count(), 1 - ) +@pytest.mark.asyncio +async def test_filtering(db): + tournament = Tournament(name="Tournament") + await 
tournament.save() + + second_tournament = Tournament(name="Tournament 2") + await second_tournament.save() + + event_first = Event(name="1", tournament=tournament) + await event_first.save() + event_second = Event(name="2", tournament=second_tournament) + await event_second.save() + event_third = Event(name="3", tournament=tournament) + await event_third.save() + event_forth = Event(name="4", tournament=second_tournament) + await event_forth.save() + + team_first = Team(name="First") + await team_first.save() + team_second = Team(name="Second") + await team_second.save() + + await team_first.events.add(event_first) + await event_second.participants.add(team_second) + + found_events = ( + await Event.filter(Q(pk__in=[event_first.pk, event_second.pk]) | Q(name="3")) + .filter(participants__not=team_second.id) + .order_by("name", "tournament_id") + .distinct() + ) + assert len(found_events) == 2 + assert found_events[0].pk == event_first.pk + assert found_events[1].pk == event_third.pk + await Team.filter(events__tournament_id=tournament.id).order_by("-events__name") + await Tournament.filter(events__name__in=["1", "3"]).distinct() + + teams = await Team.filter(name__icontains="CON") + assert len(teams) == 1 + assert teams[0].name == "Second" - async def test_filter_not_with_or(self): - await Tournament.create(name="0") - await Tournament.create(name="1") - await Tournament.create(name="2") - - tournaments = await Tournament.filter(Q(name="1") | ~Q(name="2")) - self.assertEqual(len(tournaments), 2) - self.assertSetEqual({t.name for t in tournaments}, {"0", "1"}) - - @test.requireCapability(dialect=In("postgres", "mysql")) - async def test_filter_exact(self): - obj = await DatetimeFields.create( - datetime=datetime.datetime( - year=2020, month=5, day=20, hour=0, minute=0, second=0, microsecond=0 - ) - ) - self.assertEqual(await DatetimeFields.filter(datetime__year=2020).count(), 1) - self.assertEqual(await DatetimeFields.filter(datetime__quarter=2).count(), 1) - 
self.assertEqual(await DatetimeFields.filter(datetime__month=5).count(), 1) - self.assertEqual(await DatetimeFields.filter(datetime__day=20).count(), 1) - if test._TORTOISE_TEST_DB.startswith("mysql"): - self.assertEqual(await DatetimeFields.filter(datetime__week=20).count(), 1) - self.assertEqual(await DatetimeFields.filter(datetime__hour=0).count(), 1) - else: - # PostgreSQL enables tzinfo by default - dt = obj.datetime.astimezone() - week = dt.isocalendar()[1] - self.assertEqual(await DatetimeFields.filter(datetime__week=week).count(), 1) - self.assertEqual(await DatetimeFields.filter(datetime__hour=dt.hour).count(), 1) - self.assertEqual(await DatetimeFields.filter(datetime__minute=0).count(), 1) - self.assertEqual(await DatetimeFields.filter(datetime__second=0).count(), 1) - self.assertEqual(await DatetimeFields.filter(datetime__microsecond=0).count(), 1) - - await DateFields.create(date=datetime.date(year=2021, month=6, day=21)) - self.assertEqual(await DateFields.filter(date__year=-2021).count(), 0) - self.assertEqual(await DateFields.filter(date__year=2021).count(), 1) - self.assertEqual(await DateFields.filter(date__month=6).count(), 1) - self.assertEqual(await DateFields.filter(date__day=21).count(), 1) - self.assertEqual(await DateFields.filter(date__year="2021").count(), 1) - self.assertEqual(await DateFields.filter(date__year=2021.0).count(), 1) - self.assertEqual(await DateFields.filter(date="20210621").count(), 1) - self.assertEqual(await DateFields.filter(date="2021-06-21").count(), 1) - self.assertEqual( - await DateFields.filter(date=datetime.date(year=2021, month=6, day=21)).count(), 1 - ) + teams = await Team.filter(name__iexact="SeCoNd") + assert len(teams) == 1 + assert teams[0].name == "Second" - async def test_filter_by_aggregation_field(self): - tournament = await Tournament.create(name="0") - await Tournament.create(name="1") - await Event.create(name="2", tournament=tournament) + tournaments = await 
Tournament.filter(events__participants__name__startswith="Fir") + assert len(tournaments) == 1 + assert tournaments[0] == tournament - tournaments = await Tournament.annotate(events_count=Count("events")).filter(events_count=1) - self.assertEqual(len(tournaments), 1) - self.assertEqual(tournaments[0].id, tournament.id) - async def test_filter_by_aggregation_field_with_and(self): - tournament = await Tournament.create(name="0") - tournament_second = await Tournament.create(name="1") - await Event.create(name="1", tournament=tournament) - await Event.create(name="2", tournament=tournament_second) +@pytest.mark.asyncio +async def test_q_object_backward_related_query(db): + await Tournament.create(name="0") + tournament = await Tournament.create(name="Tournament") + event = await Event.create(name="1", tournament=tournament) + fetched_tournament = await Tournament.filter(events=event.event_id).first() + assert fetched_tournament.id == tournament.id - tournaments = await Tournament.annotate(events_count=Count("events")).filter( - events_count=1, name="0" - ) - self.assertEqual(len(tournaments), 1) - self.assertEqual(tournaments[0].id, tournament.id) + fetched_tournament = await Tournament.filter(Q(events=event.event_id)).first() + assert fetched_tournament.id == tournament.id - async def test_filter_by_aggregation_field_with_and_as_one_node(self): - tournament = await Tournament.create(name="0") - tournament_second = await Tournament.create(name="1") - await Event.create(name="1", tournament=tournament) - await Event.create(name="2", tournament=tournament_second) - tournaments = await Tournament.annotate(events_count=Count("events")).filter( - Q(events_count=1, name="0") - ) - self.assertEqual(len(tournaments), 1) - self.assertEqual(tournaments[0].id, tournament.id) +@pytest.mark.asyncio +async def test_q_object_related_query(db): + tournament_first = await Tournament.create(name="0") + tournament_second = await Tournament.create(name="1") + event = await 
Event.create(name="1", tournament=tournament_second) + await Event.create(name="1", tournament=tournament_first) + + fetched_event = await Event.filter(tournament=tournament_second).first() + assert fetched_event.pk == event.pk - async def test_filter_by_aggregation_field_with_and_as_two_nodes(self): - tournament = await Tournament.create(name="0") - tournament_second = await Tournament.create(name="1") - await Event.create(name="1", tournament=tournament) - await Event.create(name="2", tournament=tournament_second) + fetched_event = await Event.filter(Q(tournament=tournament_second)).first() + assert fetched_event.pk == event.pk + + fetched_event = await Event.filter(Q(tournament=tournament_second.id)).first() + assert fetched_event.pk == event.pk + + +@pytest.mark.asyncio +async def test_null_filter(db): + tournament = await Tournament.create(name="Tournament") + reporter = await Reporter.create(name="John") + await Event.create(name="2", tournament=tournament, reporter=reporter) + event = await Event.create(name="1", tournament=tournament) + fetched_events = await Event.filter(reporter=None) + assert len(fetched_events) == 1 + assert fetched_events[0].pk == event.pk + + +@pytest.mark.asyncio +async def test_exclude(db): + await Tournament.create(name="0") + tournament = await Tournament.create(name="1") + + tournaments = await Tournament.exclude(name="0") + assert len(tournaments) == 1 + assert tournaments[0].name == tournament.name + + +@pytest.mark.asyncio +async def test_exclude_with_filter(db): + await Tournament.create(name="0") + tournament = await Tournament.create(name="1") + await Tournament.create(name="2") + + tournaments = await Tournament.exclude(name="0").filter(id=tournament.id) + assert len(tournaments) == 1 + assert tournaments[0].name == tournament.name + + +@pytest.mark.asyncio +async def test_filter_null_on_related(db): + tournament = await Tournament.create(name="Tournament") + reporter = await Reporter.create(name="John") + event_first = 
await Event.create(name="1", tournament=tournament, reporter=reporter) + event_second = await Event.create(name="2", tournament=tournament) + + team_first = await Team.create(name="1") + team_second = await Team.create(name="2") + await event_first.participants.add(team_first) + await event_second.participants.add(team_second) + + fetched_teams = await Team.filter(events__reporter=None) + assert len(fetched_teams) == 1 + assert fetched_teams[0].id == team_second.id + + +@pytest.mark.asyncio +async def test_filter_or(db): + await Tournament.create(name="0") + await Tournament.create(name="1") + await Tournament.create(name="2") - tournaments = await Tournament.annotate(events_count=Count("events")).filter( - Q(events_count=1) & Q(name="0") - ) - self.assertEqual(len(tournaments), 1) - self.assertEqual(tournaments[0].id, tournament.id) + tournaments = await Tournament.filter(Q(name="1") | Q(name="2")) + assert len(tournaments) == 2 + assert {t.name for t in tournaments} == {"1", "2"} - async def test_filter_by_aggregation_field_with_or(self): - tournament = await Tournament.create(name="0") - await Tournament.create(name="1") - await Tournament.create(name="2") - await Event.create(name="1", tournament=tournament) - tournaments = await Tournament.annotate(events_count=Count("events")).filter( - Q(events_count=1) | Q(name="2") - ) - self.assertEqual(len(tournaments), 2) - self.assertSetEqual({t.name for t in tournaments}, {"0", "2"}) +@pytest.mark.asyncio +async def test_filter_not(db): + await Tournament.create(name="0") + await Tournament.create(name="1") - async def test_filter_by_aggregation_field_with_or_reversed(self): - tournament = await Tournament.create(name="0") - await Tournament.create(name="1") - await Tournament.create(name="2") - await Event.create(name="1", tournament=tournament) + tournaments = await Tournament.filter(~Q(name="1")) + assert len(tournaments) == 1 + assert tournaments[0].name == "0" - tournaments = await 
Tournament.annotate(events_count=Count("events")).filter( - Q(name="2") | Q(events_count=1) - ) - self.assertEqual(len(tournaments), 2) - self.assertSetEqual({t.name for t in tournaments}, {"0", "2"}) - async def test_filter_by_aggregation_field_with_or_as_one_node(self): - tournament = await Tournament.create(name="0") - await Tournament.create(name="1") - await Tournament.create(name="2") - await Event.create(name="1", tournament=tournament) +@pytest.mark.asyncio +async def test_filter_with_f_expression(db): + await IntFields.create(intnum=1, intnum_null=1) + await IntFields.create(intnum=2, intnum_null=1) + assert await IntFields.filter(intnum=F("intnum_null")).count() == 1 + assert await IntFields.filter(intnum__gte=F("intnum_null")).count() == 2 + assert await IntFields.filter(intnum=F("intnum_null") + F("intnum_null")).count() == 1 - tournaments = await Tournament.annotate(events_count=Count("events")).filter( - Q(events_count=1, name="2", join_type=Q.OR) - ) - self.assertEqual(len(tournaments), 2) - self.assertSetEqual({t.name for t in tournaments}, {"0", "2"}) - async def test_filter_by_aggregation_field_with_not(self): - tournament = await Tournament.create(name="0") - tournament_second = await Tournament.create(name="1") - await Event.create(name="1", tournament=tournament) - await Event.create(name="2", tournament=tournament_second) +@pytest.mark.asyncio +async def test_filter_not_with_or(db): + await Tournament.create(name="0") + await Tournament.create(name="1") + await Tournament.create(name="2") - tournaments = await Tournament.annotate(events_count=Count("events")).filter( - ~Q(events_count=1, name="0") - ) - self.assertEqual(len(tournaments), 1) - self.assertEqual(tournaments[0].id, tournament_second.id) + tournaments = await Tournament.filter(Q(name="1") | ~Q(name="2")) + assert len(tournaments) == 2 + assert {t.name for t in tournaments} == {"0", "1"} - async def test_filter_by_aggregation_field_with_or_not(self): - tournament = await 
Tournament.create(name="0") - await Tournament.create(name="1") - await Tournament.create(name="2") - await Event.create(name="1", tournament=tournament) - tournaments = await Tournament.annotate(events_count=Count("events")).filter( - ~(Q(events_count=1) | Q(name="2")) +@test.requireCapability(dialect=In("postgres", "mysql")) +@pytest.mark.asyncio +async def test_filter_exact(db): + obj = await DatetimeFields.create( + datetime=datetime.datetime( + year=2020, month=5, day=20, hour=0, minute=0, second=0, microsecond=0 ) - self.assertEqual(len(tournaments), 1) - self.assertSetEqual({t.name for t in tournaments}, {"1"}) + ) + assert await DatetimeFields.filter(datetime__year=2020).count() == 1 + assert await DatetimeFields.filter(datetime__quarter=2).count() == 1 + assert await DatetimeFields.filter(datetime__month=5).count() == 1 + assert await DatetimeFields.filter(datetime__day=20).count() == 1 + if db.db().capabilities.dialect == "mysql": + assert await DatetimeFields.filter(datetime__week=20).count() == 1 + assert await DatetimeFields.filter(datetime__hour=0).count() == 1 + else: + # PostgreSQL enables tzinfo by default + dt = obj.datetime.astimezone() + week = dt.isocalendar()[1] + assert await DatetimeFields.filter(datetime__week=week).count() == 1 + assert await DatetimeFields.filter(datetime__hour=dt.hour).count() == 1 + assert await DatetimeFields.filter(datetime__minute=0).count() == 1 + assert await DatetimeFields.filter(datetime__second=0).count() == 1 + assert await DatetimeFields.filter(datetime__microsecond=0).count() == 1 + + await DateFields.create(date=datetime.date(year=2021, month=6, day=21)) + assert await DateFields.filter(date__year=-2021).count() == 0 + assert await DateFields.filter(date__year=2021).count() == 1 + assert await DateFields.filter(date__month=6).count() == 1 + assert await DateFields.filter(date__day=21).count() == 1 + assert await DateFields.filter(date__year="2021").count() == 1 + assert await 
DateFields.filter(date__year=2021.0).count() == 1 + assert await DateFields.filter(date="20210621").count() == 1 + assert await DateFields.filter(date="2021-06-21").count() == 1 + assert await DateFields.filter(date=datetime.date(year=2021, month=6, day=21)).count() == 1 + + +@pytest.mark.asyncio +async def test_filter_by_aggregation_field(db): + tournament = await Tournament.create(name="0") + await Tournament.create(name="1") + await Event.create(name="2", tournament=tournament) + + tournaments = await Tournament.annotate(events_count=Count("events")).filter(events_count=1) + assert len(tournaments) == 1 + assert tournaments[0].id == tournament.id + + +@pytest.mark.asyncio +async def test_filter_by_aggregation_field_with_and(db): + tournament = await Tournament.create(name="0") + tournament_second = await Tournament.create(name="1") + await Event.create(name="1", tournament=tournament) + await Event.create(name="2", tournament=tournament_second) + + tournaments = await Tournament.annotate(events_count=Count("events")).filter( + events_count=1, name="0" + ) + assert len(tournaments) == 1 + assert tournaments[0].id == tournament.id + + +@pytest.mark.asyncio +async def test_filter_by_aggregation_field_with_and_as_one_node(db): + tournament = await Tournament.create(name="0") + tournament_second = await Tournament.create(name="1") + await Event.create(name="1", tournament=tournament) + await Event.create(name="2", tournament=tournament_second) + + tournaments = await Tournament.annotate(events_count=Count("events")).filter( + Q(events_count=1, name="0") + ) + assert len(tournaments) == 1 + assert tournaments[0].id == tournament.id + + +@pytest.mark.asyncio +async def test_filter_by_aggregation_field_with_and_as_two_nodes(db): + tournament = await Tournament.create(name="0") + tournament_second = await Tournament.create(name="1") + await Event.create(name="1", tournament=tournament) + await Event.create(name="2", tournament=tournament_second) + + tournaments = await 
Tournament.annotate(events_count=Count("events")).filter( + Q(events_count=1) & Q(name="0") + ) + assert len(tournaments) == 1 + assert tournaments[0].id == tournament.id + + +@pytest.mark.asyncio +async def test_filter_by_aggregation_field_with_or(db): + tournament = await Tournament.create(name="0") + await Tournament.create(name="1") + await Tournament.create(name="2") + await Event.create(name="1", tournament=tournament) + + tournaments = await Tournament.annotate(events_count=Count("events")).filter( + Q(events_count=1) | Q(name="2") + ) + assert len(tournaments) == 2 + assert {t.name for t in tournaments} == {"0", "2"} + + +@pytest.mark.asyncio +async def test_filter_by_aggregation_field_with_or_reversed(db): + tournament = await Tournament.create(name="0") + await Tournament.create(name="1") + await Tournament.create(name="2") + await Event.create(name="1", tournament=tournament) + + tournaments = await Tournament.annotate(events_count=Count("events")).filter( + Q(name="2") | Q(events_count=1) + ) + assert len(tournaments) == 2 + assert {t.name for t in tournaments} == {"0", "2"} + + +@pytest.mark.asyncio +async def test_filter_by_aggregation_field_with_or_as_one_node(db): + tournament = await Tournament.create(name="0") + await Tournament.create(name="1") + await Tournament.create(name="2") + await Event.create(name="1", tournament=tournament) + + tournaments = await Tournament.annotate(events_count=Count("events")).filter( + Q(events_count=1, name="2", join_type=Q.OR) + ) + assert len(tournaments) == 2 + assert {t.name for t in tournaments} == {"0", "2"} + + +@pytest.mark.asyncio +async def test_filter_by_aggregation_field_with_not(db): + tournament = await Tournament.create(name="0") + tournament_second = await Tournament.create(name="1") + await Event.create(name="1", tournament=tournament) + await Event.create(name="2", tournament=tournament_second) + + tournaments = await Tournament.annotate(events_count=Count("events")).filter( + ~Q(events_count=1, 
name="0") + ) + assert len(tournaments) == 1 + assert tournaments[0].id == tournament_second.id + + +@pytest.mark.asyncio +async def test_filter_by_aggregation_field_with_or_not(db): + tournament = await Tournament.create(name="0") + await Tournament.create(name="1") + await Tournament.create(name="2") + await Event.create(name="1", tournament=tournament) + + tournaments = await Tournament.annotate(events_count=Count("events")).filter( + ~(Q(events_count=1) | Q(name="2")) + ) + assert len(tournaments) == 1 + assert {t.name for t in tournaments} == {"1"} + + +@pytest.mark.asyncio +async def test_filter_by_aggregation_field_with_or_not_reversed(db): + tournament = await Tournament.create(name="0") + await Tournament.create(name="1") + await Tournament.create(name="2") + await Event.create(name="1", tournament=tournament) + + tournaments = await Tournament.annotate(events_count=Count("events")).filter( + ~(Q(name="2") | Q(events_count=1)) + ) + assert len(tournaments) == 1 + assert {t.name for t in tournaments} == {"1"} + + +@pytest.mark.asyncio +async def test_filter_by_aggregation_field_trim(db): + await Tournament.create(name=" 1 ") + await Tournament.create(name="2 ") + + tournaments = await Tournament.annotate(trimmed_name=Trim("name")).filter(trimmed_name="1") + assert len(tournaments) == 1 + assert {(t.name, t.trimmed_name) for t in tournaments} == {(" 1 ", "1")} + + +@test.requireCapability(dialect=NotEQ("mssql")) +@pytest.mark.asyncio +async def test_filter_by_aggregation_field_length(db): + await Tournament.create(name="12345") + await Tournament.create(name="123") + await Tournament.create(name="1234") + + tournaments = await Tournament.annotate(name_len=Length("name")).filter(name_len__gte=4) + assert len(tournaments) == 2 + assert {t.name for t in tournaments} == {"1234", "12345"} + + +@pytest.mark.asyncio +async def test_filter_by_aggregation_field_coalesce(db): + await Tournament.create(name="1", desc="demo") + await Tournament.create(name="2") + + 
tournaments = await Tournament.annotate(clean_desc=Coalesce("desc", "demo")).filter( + clean_desc="demo" + ) + assert len(tournaments) == 2 + assert {(t.name, t.clean_desc) for t in tournaments} == {("1", "demo"), ("2", "demo")} + + +@pytest.mark.asyncio +async def test_filter_by_aggregation_field_coalesce_numeric(db): + await IntFields.create(intnum=1, intnum_null=10) + await IntFields.create(intnum=4) + + ints = await IntFields.annotate(clean_intnum_null=Coalesce("intnum_null", 0)).filter( + clean_intnum_null__in=(0, 10) + ) + assert len(ints) == 2 + assert {(i.intnum_null, i.clean_intnum_null) for i in ints} == {(None, 0), (10, 10)} + + +@pytest.mark.asyncio +async def test_filter_by_aggregation_field_comparison_coalesce_numeric(db): + await IntFields.create(intnum=3, intnum_null=10) + await IntFields.create(intnum=1, intnum_null=4) + await IntFields.create(intnum=2) + + ints = await IntFields.annotate(clean_intnum_null=Coalesce("intnum_null", 0)).filter( + clean_intnum_null__gt=0 + ) + assert len(ints) == 2 + assert {i.clean_intnum_null for i in ints} == {10, 4} + + +@test.requireCapability(dialect=NotEQ("mssql")) +@pytest.mark.asyncio +async def test_filter_by_aggregation_field_comparison_length(db): + t1 = await Tournament.create(name="Tournament") + await Event.create(name="event1", tournament=t1) + await Event.create(name="event2", tournament=t1) + t2 = await Tournament.create(name="contest") + await Event.create(name="event3", tournament=t2) + await Tournament.create(name="Championship") + t4 = await Tournament.create(name="local") + await Event.create(name="event4", tournament=t4) + await Event.create(name="event5", tournament=t4) + tournaments = await Tournament.annotate( + name_len=Length("name"), event_count=Count("events") + ).filter(name_len__gt=5, event_count=2) + assert len(tournaments) == 1 + assert {t.name for t in tournaments} == {"Tournament"} + + +@pytest.mark.asyncio +async def test_filter_by_annotation_lower(db): + await 
Tournament.create(name="Tournament") + await Tournament.create(name="NEW Tournament") + tournaments = await Tournament.annotate(name_lower=Lower("name")) + assert len(tournaments) == 2 + assert {t.name_lower for t in tournaments} == {"tournament", "new tournament"} + + +@pytest.mark.asyncio +async def test_filter_by_annotation_upper(db): + await Tournament.create(name="ToUrnAmEnT") + await Tournament.create(name="new TOURnament") + tournaments = await Tournament.annotate(name_upper=Upper("name")) + assert len(tournaments) == 2 + assert {t.name_upper for t in tournaments} == {"TOURNAMENT", "NEW TOURNAMENT"} + + +@pytest.mark.asyncio +async def test_order_by_annotation(db): + t1 = await Tournament.create(name="Tournament") + await Event.create(name="event1", tournament=t1) + await Event.create(name="event2", tournament=t1) + + res = await Event.filter(tournament=t1).annotate(max_id=Max("event_id")).order_by("-event_id") + assert len(res) == 2 + assert res[0].event_id > res[1].event_id + assert res[0].max_id == res[0].event_id + assert res[1].max_id == res[1].event_id + + +@pytest.mark.asyncio +async def test_values_select_relation(db): + with pytest.raises(ValueError): + tournament = await Tournament.create(name="New Tournament") + await Event.create(name="Test", tournament_id=tournament.id) + await Event.all().values("tournament") - async def test_filter_by_aggregation_field_with_or_not_reversed(self): - tournament = await Tournament.create(name="0") - await Tournament.create(name="1") - await Tournament.create(name="2") - await Event.create(name="1", tournament=tournament) - tournaments = await Tournament.annotate(events_count=Count("events")).filter( - ~(Q(name="2") | Q(events_count=1)) - ) - self.assertEqual(len(tournaments), 1) - self.assertSetEqual({t.name for t in tournaments}, {"1"}) +@pytest.mark.asyncio +async def test_values_select_relation_field(db): + tournament = await Tournament.create(name="New Tournament") + await Event.create(name="Test", 
tournament_id=tournament.id) + event_tournaments = await Event.all().values("tournament__name") + assert event_tournaments[0]["tournament__name"] == tournament.name - async def test_filter_by_aggregation_field_trim(self): - await Tournament.create(name=" 1 ") - await Tournament.create(name="2 ") - tournaments = await Tournament.annotate(trimmed_name=Trim("name")).filter(trimmed_name="1") - self.assertEqual(len(tournaments), 1) - self.assertSetEqual({(t.name, t.trimmed_name) for t in tournaments}, {(" 1 ", "1")}) +@pytest.mark.asyncio +async def test_values_select_relation_field_name_override(db): + tournament = await Tournament.create(name="New Tournament") + await Event.create(name="Test", tournament_id=tournament.id) + event_tournaments = await Event.all().values(tour="tournament__name") + assert event_tournaments[0]["tour"] == tournament.name - @test.requireCapability(dialect=NotEQ("mssql")) - async def test_filter_by_aggregation_field_length(self): - await Tournament.create(name="12345") - await Tournament.create(name="123") - await Tournament.create(name="1234") - tournaments = await Tournament.annotate(name_len=Length("name")).filter(name_len__gte=4) - self.assertEqual(len(tournaments), 2) - self.assertSetEqual({t.name for t in tournaments}, {"1234", "12345"}) +@pytest.mark.asyncio +async def test_values_list_select_relation_field(db): + tournament = await Tournament.create(name="New Tournament") + await Event.create(name="Test", tournament_id=tournament.id) + event_tournaments = await Event.all().values_list("tournament__name") + assert event_tournaments[0][0] == tournament.name - async def test_filter_by_aggregation_field_coalesce(self): - await Tournament.create(name="1", desc="demo") - await Tournament.create(name="2") - tournaments = await Tournament.annotate(clean_desc=Coalesce("desc", "demo")).filter( - clean_desc="demo" - ) - self.assertEqual(len(tournaments), 2) - self.assertSetEqual( - {(t.name, t.clean_desc) for t in tournaments}, {("1", "demo"), 
("2", "demo")} - ) +@pytest.mark.asyncio +async def test_annotation_in_case_when(db): + await Tournament.create(name="Tournament") + await Tournament.create(name="NEW Tournament") + tournaments = ( + await Tournament.annotate(name_lower=Lower("name")) + .annotate(is_tournament=Case(When(Q(name_lower="tournament"), then="yes"), default="no")) + .filter(is_tournament="yes") + ) + assert len(tournaments) == 1 + assert tournaments[0].name == "Tournament" + assert tournaments[0].name_lower == "tournament" + assert tournaments[0].is_tournament == "yes" - async def test_filter_by_aggregation_field_coalesce_numeric(self): - await IntFields.create(intnum=1, intnum_null=10) - await IntFields.create(intnum=4) - ints = await IntFields.annotate(clean_intnum_null=Coalesce("intnum_null", 0)).filter( - clean_intnum_null__in=(0, 10) - ) - self.assertEqual(len(ints), 2) - self.assertSetEqual( - {(i.intnum_null, i.clean_intnum_null) for i in ints}, {(None, 0), (10, 10)} - ) +@pytest.mark.asyncio +async def test_f_annotation_filter(db): + event = await IntFields.create(intnum=1) - async def test_filter_by_aggregation_field_comparison_coalesce_numeric(self): - await IntFields.create(intnum=3, intnum_null=10) - await IntFields.create(intnum=1, intnum_null=4) - await IntFields.create(intnum=2) + ret_events = await IntFields.annotate(intnum_plus_1=F("intnum") + 1).filter(intnum_plus_1=2) + assert ret_events == [event] - ints = await IntFields.annotate(clean_intnum_null=Coalesce("intnum_null", 0)).filter( - clean_intnum_null__gt=0 - ) - self.assertEqual(len(ints), 2) - self.assertSetEqual({i.clean_intnum_null for i in ints}, {10, 4}) - - @test.requireCapability(dialect=NotEQ("mssql")) - async def test_filter_by_aggregation_field_comparison_length(self): - t1 = await Tournament.create(name="Tournament") - await Event.create(name="event1", tournament=t1) - await Event.create(name="event2", tournament=t1) - t2 = await Tournament.create(name="contest") - await Event.create(name="event3", 
tournament=t2) - await Tournament.create(name="Championship") - t4 = await Tournament.create(name="local") - await Event.create(name="event4", tournament=t4) - await Event.create(name="event5", tournament=t4) - tournaments = await Tournament.annotate( - name_len=Length("name"), event_count=Count("events") - ).filter(name_len__gt=5, event_count=2) - self.assertEqual(len(tournaments), 1) - self.assertSetEqual({t.name for t in tournaments}, {"Tournament"}) - - async def test_filter_by_annotation_lower(self): - await Tournament.create(name="Tournament") - await Tournament.create(name="NEW Tournament") - tournaments = await Tournament.annotate(name_lower=Lower("name")) - self.assertEqual(len(tournaments), 2) - self.assertSetEqual({t.name_lower for t in tournaments}, {"tournament", "new tournament"}) - - async def test_filter_by_annotation_upper(self): - await Tournament.create(name="ToUrnAmEnT") - await Tournament.create(name="new TOURnament") - tournaments = await Tournament.annotate(name_upper=Upper("name")) - self.assertEqual(len(tournaments), 2) - self.assertSetEqual({t.name_upper for t in tournaments}, {"TOURNAMENT", "NEW TOURNAMENT"}) - - async def test_order_by_annotation(self): - t1 = await Tournament.create(name="Tournament") - await Event.create(name="event1", tournament=t1) - await Event.create(name="event2", tournament=t1) - - res = ( - await Event.filter(tournament=t1).annotate(max_id=Max("event_id")).order_by("-event_id") - ) - self.assertEqual(len(res), 2) - self.assertGreater(res[0].event_id, res[1].event_id) - self.assertEqual(res[0].max_id, res[0].event_id) - self.assertEqual(res[1].max_id, res[1].event_id) - - async def test_values_select_relation(self): - with self.assertRaises(ValueError): - tournament = await Tournament.create(name="New Tournament") - await Event.create(name="Test", tournament_id=tournament.id) - await Event.all().values("tournament") - - async def test_values_select_relation_field(self): - tournament = await 
Tournament.create(name="New Tournament") - await Event.create(name="Test", tournament_id=tournament.id) - event_tournaments = await Event.all().values("tournament__name") - self.assertEqual(event_tournaments[0]["tournament__name"], tournament.name) - async def test_values_select_relation_field_name_override(self): - tournament = await Tournament.create(name="New Tournament") - await Event.create(name="Test", tournament_id=tournament.id) - event_tournaments = await Event.all().values(tour="tournament__name") - self.assertEqual(event_tournaments[0]["tour"], tournament.name) +@pytest.mark.asyncio +async def test_f_annotation_custom_filter(db): + event = await IntFields.create(intnum=1) - async def test_values_list_select_relation_field(self): - tournament = await Tournament.create(name="New Tournament") - await Event.create(name="Test", tournament_id=tournament.id) - event_tournaments = await Event.all().values_list("tournament__name") - self.assertEqual(event_tournaments[0][0], tournament.name) - - async def test_annotation_in_case_when(self): - await Tournament.create(name="Tournament") - await Tournament.create(name="NEW Tournament") - tournaments = ( - await Tournament.annotate(name_lower=Lower("name")) - .annotate( - is_tournament=Case(When(Q(name_lower="tournament"), then="yes"), default="no") - ) - .filter(is_tournament="yes") - ) - self.assertEqual(len(tournaments), 1) - self.assertEqual(tournaments[0].name, "Tournament") - self.assertEqual(tournaments[0].name_lower, "tournament") - self.assertEqual(tournaments[0].is_tournament, "yes") + base_query = IntFields.annotate(intnum_plus_1=F("intnum") + 1) - async def test_f_annotation_filter(self): - event = await IntFields.create(intnum=1) + ret_events = await base_query.filter(intnum_plus_1__gt=1) + assert ret_events == [event] - ret_events = await IntFields.annotate(intnum_plus_1=F("intnum") + 1).filter(intnum_plus_1=2) - self.assertEqual(ret_events, [event]) + ret_events = await 
base_query.filter(intnum_plus_1__lt=3) + assert ret_events == [event] - async def test_f_annotation_custom_filter(self): - event = await IntFields.create(intnum=1) + ret_events = await base_query.filter(Q(intnum_plus_1__gt=1) & Q(intnum_plus_1__lt=3)) + assert ret_events == [event] - base_query = IntFields.annotate(intnum_plus_1=F("intnum") + 1) + ret_events = await base_query.filter(intnum_plus_1__isnull=True) + assert ret_events == [] - ret_events = await base_query.filter(intnum_plus_1__gt=1) - self.assertEqual(ret_events, [event]) - ret_events = await base_query.filter(intnum_plus_1__lt=3) - self.assertEqual(ret_events, [event]) +@pytest.mark.asyncio +async def test_f_annotation_join(db): + tournament_a = await Tournament.create(name="A") + tournament_b = await Tournament.create(name="B") + await Tournament.create(name="C") + event_a = await Event.create(name="A", tournament=tournament_a) + await Event.create(name="B", tournament=tournament_b) - ret_events = await base_query.filter(Q(intnum_plus_1__gt=1) & Q(intnum_plus_1__lt=3)) - self.assertEqual(ret_events, [event]) + events = ( + await Event.all() + .annotate(tournament_name=F("tournament__name")) + .filter(tournament_name="A") + ) + assert events == [event_a] - ret_events = await base_query.filter(intnum_plus_1__isnull=True) - self.assertEqual(ret_events, []) - async def test_f_annotation_join(self): - tournament_a = await Tournament.create(name="A") - tournament_b = await Tournament.create(name="B") - await Tournament.create(name="C") - event_a = await Event.create(name="A", tournament=tournament_a) - await Event.create(name="B", tournament=tournament_b) +@pytest.mark.asyncio +async def test_f_annotation_custom_filter_requiring_join(db): + tournament_a = await Tournament.create(name="A") + tournament_b = await Tournament.create(name="B") + await Tournament.create(name="C") + await Event.create(name="A", tournament=tournament_a) + event_b = await Event.create(name="B", tournament=tournament_b) - events = ( 
- await Event.all() - .annotate(tournament_name=F("tournament__name")) - .filter(tournament_name="A") - ) - self.assertEqual(events, [event_a]) - - async def test_f_annotation_custom_filter_requiring_join(self): - tournament_a = await Tournament.create(name="A") - tournament_b = await Tournament.create(name="B") - await Tournament.create(name="C") - await Event.create(name="A", tournament=tournament_a) - event_b = await Event.create(name="B", tournament=tournament_b) - - events = ( - await Event.all() - .annotate(tournament_name=F("tournament__name")) - .filter(tournament_name__gt="A") - ) - self.assertEqual(events, [event_b]) + events = ( + await Event.all() + .annotate(tournament_name=F("tournament__name")) + .filter(tournament_name__gt="A") + ) + assert events == [event_b] - async def test_f_annotation_custom_filter_requiring_nested_joins(self): - tournament = await Tournament.create(name="Tournament") - second_tournament = await Tournament.create(name="Tournament 2") +@pytest.mark.asyncio +async def test_f_annotation_custom_filter_requiring_nested_joins(db): + tournament = await Tournament.create(name="Tournament") - event_first = await Event.create(name="1", tournament=tournament) - event_second = await Event.create(name="2", tournament=second_tournament) - await Event.create(name="3", tournament=tournament) - await Event.create(name="4", tournament=second_tournament) + second_tournament = await Tournament.create(name="Tournament 2") - team_first = await Team.create(name="First") - team_second = await Team.create(name="Second") + event_first = await Event.create(name="1", tournament=tournament) + event_second = await Event.create(name="2", tournament=second_tournament) + await Event.create(name="3", tournament=tournament) + await Event.create(name="4", tournament=second_tournament) - await team_first.events.add(event_first) - await event_second.participants.add(team_second) + team_first = await Team.create(name="First") + team_second = await 
Team.create(name="Second") - res = await Tournament.annotate(pname=F("events__participants__name")).filter( - pname__startswith="Fir" - ) - self.assertEqual(res, [tournament]) + await team_first.events.add(event_first) + await event_second.participants.add(team_second) + + res = await Tournament.annotate(pname=F("events__participants__name")).filter( + pname__startswith="Fir" + ) + assert res == [tournament] diff --git a/tests/test_filters.py b/tests/test_filters.py index 3454c20da..bb137bea4 100644 --- a/tests/test_filters.py +++ b/tests/test_filters.py @@ -1,6 +1,9 @@ from decimal import Decimal from enum import Enum +import pytest +import pytest_asyncio + from tests.testmodels import ( BooleanFields, CharFields, @@ -8,7 +11,6 @@ CharPkModel, DecimalFields, ) -from tortoise.contrib import test from tortoise.exceptions import FieldError from tortoise.fields.base import StrEnum @@ -21,363 +23,394 @@ class MyStrEnum(StrEnum): moo = "moo" -class TestCharFieldFilters(test.TestCase): - async def asyncSetUp(self): - await super().asyncSetUp() - await CharFields.create(char="moo") - await CharFields.create(char="baa", char_null="baa") - await CharFields.create(char="oink") - - async def test_bad_param(self): - with self.assertRaisesRegex( - FieldError, "Unknown filter param 'charup'. 
Allowed base values are" - ): - await CharFields.filter(charup="moo") - - async def test_equal(self): - self.assertEqual( - set(await CharFields.filter(char="moo").values_list("char", flat=True)), {"moo"} - ) - - async def test_enum(self): - self.assertEqual( - set(await CharFields.filter(char=MyEnum.moo).values_list("char", flat=True)), {"moo"} - ) - self.assertEqual( - set(await CharFields.filter(char=MyStrEnum.moo).values_list("char", flat=True)), {"moo"} - ) - - async def test_not(self): - self.assertEqual( - set(await CharFields.filter(char__not="moo").values_list("char", flat=True)), - {"baa", "oink"}, - ) - - async def test_in(self): - self.assertSetEqual( - set(await CharFields.filter(char__in=["moo", "baa"]).values_list("char", flat=True)), - {"moo", "baa"}, - ) - - async def test_in_empty(self): - self.assertEqual( - await CharFields.filter(char__in=[]).values_list("char", flat=True), - [], - ) - - async def test_not_in(self): - self.assertSetEqual( - set( - await CharFields.filter(char__not_in=["moo", "baa"]).values_list("char", flat=True) - ), - {"oink"}, - ) - - async def test_not_in_empty(self): - self.assertSetEqual( - set(await CharFields.filter(char__not_in=[]).values_list("char", flat=True)), - {"oink", "moo", "baa"}, - ) - - async def test_isnull(self): - self.assertSetEqual( - set(await CharFields.filter(char_null__isnull=True).values_list("char", flat=True)), - {"moo", "oink"}, - ) - self.assertSetEqual( - set(await CharFields.filter(char_null__isnull=False).values_list("char", flat=True)), - {"baa"}, - ) - - async def test_not_isnull(self): - self.assertSetEqual( - set(await CharFields.filter(char_null__not_isnull=True).values_list("char", flat=True)), - {"baa"}, - ) - self.assertSetEqual( - set( - await CharFields.filter(char_null__not_isnull=False).values_list("char", flat=True) - ), - {"moo", "oink"}, - ) - - async def test_gte(self): - self.assertSetEqual( - set(await CharFields.filter(char__gte="moo").values_list("char", flat=True)), - 
{"moo", "oink"}, - ) - - async def test_lte(self): - self.assertSetEqual( - set(await CharFields.filter(char__lte="moo").values_list("char", flat=True)), - {"moo", "baa"}, - ) - - async def test_gt(self): - self.assertSetEqual( - set(await CharFields.filter(char__gt="moo").values_list("char", flat=True)), {"oink"} - ) - - async def test_lt(self): - self.assertSetEqual( - set(await CharFields.filter(char__lt="moo").values_list("char", flat=True)), {"baa"} - ) - - async def test_contains(self): - self.assertSetEqual( - set(await CharFields.filter(char__contains="o").values_list("char", flat=True)), - {"moo", "oink"}, - ) - - async def test_startswith(self): - self.assertSetEqual( - set(await CharFields.filter(char__startswith="m").values_list("char", flat=True)), - {"moo"}, - ) - self.assertSetEqual( - set(await CharFields.filter(char__startswith="s").values_list("char", flat=True)), set() - ) - - async def test_endswith(self): - self.assertSetEqual( - set(await CharFields.filter(char__endswith="o").values_list("char", flat=True)), {"moo"} - ) - self.assertSetEqual( - set(await CharFields.filter(char__endswith="s").values_list("char", flat=True)), set() - ) - - async def test_icontains(self): - self.assertSetEqual( - set(await CharFields.filter(char__icontains="oO").values_list("char", flat=True)), - {"moo"}, - ) - self.assertSetEqual( - set(await CharFields.filter(char__icontains="Oo").values_list("char", flat=True)), - {"moo"}, - ) - - async def test_iexact(self): - self.assertSetEqual( - set(await CharFields.filter(char__iexact="MoO").values_list("char", flat=True)), {"moo"} - ) - - async def test_istartswith(self): - self.assertSetEqual( - set(await CharFields.filter(char__istartswith="m").values_list("char", flat=True)), - {"moo"}, - ) - self.assertSetEqual( - set(await CharFields.filter(char__istartswith="M").values_list("char", flat=True)), - {"moo"}, - ) - - async def test_iendswith(self): - self.assertSetEqual( - set(await 
CharFields.filter(char__iendswith="oO").values_list("char", flat=True)), - {"moo"}, - ) - self.assertSetEqual( - set(await CharFields.filter(char__iendswith="Oo").values_list("char", flat=True)), - {"moo"}, - ) - - async def test_sorting(self): - self.assertEqual( - await CharFields.all().order_by("char").values_list("char", flat=True), - ["baa", "moo", "oink"], - ) - - -class TestBooleanFieldFilters(test.TestCase): - async def asyncSetUp(self): - await super().asyncSetUp() - await BooleanFields.create(boolean=True) - await BooleanFields.create(boolean=False) - await BooleanFields.create(boolean=True, boolean_null=True) - await BooleanFields.create(boolean=False, boolean_null=True) - await BooleanFields.create(boolean=True, boolean_null=False) - await BooleanFields.create(boolean=False, boolean_null=False) - - async def test_equal_true(self): - self.assertEqual( - set(await BooleanFields.filter(boolean=True).values_list("boolean", "boolean_null")), - {(True, None), (True, True), (True, False)}, - ) - - async def test_equal_false(self): - self.assertEqual( - set(await BooleanFields.filter(boolean=False).values_list("boolean", "boolean_null")), - {(False, None), (False, True), (False, False)}, - ) - - async def test_equal_true2(self): - self.assertEqual( - set( - await BooleanFields.filter(boolean_null=True).values_list("boolean", "boolean_null") - ), - {(False, True), (True, True)}, - ) - - async def test_equal_false2(self): - self.assertEqual( - set( - await BooleanFields.filter(boolean_null=False).values_list( - "boolean", "boolean_null" - ) - ), - {(False, False), (True, False)}, - ) - - async def test_equal_null(self): - self.assertEqual( - set( - await BooleanFields.filter(boolean_null=None).values_list("boolean", "boolean_null") - ), - {(False, None), (True, None)}, - ) - - -class TestDecimalFieldFilters(test.TestCase): - async def asyncSetUp(self): - await super().asyncSetUp() - await DecimalFields.create(decimal="1.2345", decimal_nodec=1) - await 
DecimalFields.create(decimal="2.34567", decimal_nodec=1) - await DecimalFields.create(decimal="2.300", decimal_nodec=1) - await DecimalFields.create(decimal="023.0", decimal_nodec=1) - await DecimalFields.create(decimal="0.230", decimal_nodec=1) - - async def test_sorting(self): - self.assertEqual( - await DecimalFields.all().order_by("decimal").values_list("decimal", flat=True), - [Decimal("0.23"), Decimal("1.2345"), Decimal("2.3"), Decimal("2.3457"), Decimal("23")], - ) - - async def test_gt(self): - self.assertEqual( - await DecimalFields.filter(decimal__gt=Decimal("1.2345")) - .order_by("decimal") - .values_list("decimal", flat=True), - [Decimal("2.3"), Decimal("2.3457"), Decimal("23")], - ) - - async def test_between_and(self): - self.assertEqual( - await DecimalFields.filter( - decimal__range=(Decimal("1.2344"), Decimal("1.2346")) - ).values_list("decimal", flat=True), - [Decimal("1.2345")], - ) - - async def test_in(self): - self.assertEqual( - await DecimalFields.filter( - decimal__in=[Decimal("1.2345"), Decimal("1000")] - ).values_list("decimal", flat=True), - [Decimal("1.2345")], - ) - - -class TestCharFkFieldFilters(test.TestCase): - async def asyncSetUp(self): - await super().asyncSetUp() - model1 = await CharPkModel.create(id=17) - model2 = await CharPkModel.create(id=12) - await CharPkModel.create(id=2001) - await CharFkRelatedModel.create(model=model1) - await CharFkRelatedModel.create(model=model1) - await CharFkRelatedModel.create(model=model2) - - async def test_bad_param(self): - with self.assertRaisesRegex( - FieldError, "Unknown filter param 'bad_param'. 
Allowed base values are" - ): - await CharPkModel.filter(bad_param="moo") - - async def test_equal(self): - self.assertEqual( - set(await CharPkModel.filter(id=2001).values_list("id", flat=True)), {"2001"} - ) - - async def test_not(self): - self.assertEqual( - set(await CharPkModel.filter(id__not=2001).values_list("id", flat=True)), - {"17", "12"}, - ) - - async def test_in(self): - self.assertSetEqual( - set(await CharPkModel.filter(id__in=[17, 12]).values_list("id", flat=True)), - {"17", "12"}, - ) - - async def test_in_empty(self): - self.assertEqual( - await CharPkModel.filter(id__in=[]).values_list("id", flat=True), - [], - ) - - async def test_not_in(self): - self.assertSetEqual( - set(await CharPkModel.filter(id__not_in=[17, 12]).values_list("id", flat=True)), - {"2001"}, - ) - - async def test_not_in_empty(self): - self.assertSetEqual( - set(await CharPkModel.filter(id__not_in=[]).values_list("id", flat=True)), - {"17", "12", "2001"}, - ) - - async def test_isnull(self): - self.assertSetEqual( - set(await CharPkModel.filter(children__isnull=True).values_list("id", flat=True)), - {"2001"}, - ) - self.assertEqual( - await CharPkModel.filter(children__isnull=False) - .order_by("id") - .values_list("id", flat=True), - ["12", "17", "17"], - ) - - async def test_not_isnull(self): - self.assertSetEqual( - set(await CharPkModel.filter(children__not_isnull=True).values_list("id", flat=True)), - {"17", "12"}, - ) - self.assertSetEqual( - set(await CharPkModel.filter(children__not_isnull=False).values_list("id", flat=True)), - {"2001"}, - ) - - async def test_gte(self): - self.assertSetEqual( - set(await CharPkModel.filter(id__gte=17).values_list("id", flat=True)), - {"17", "2001"}, - ) - - async def test_lte(self): - self.assertSetEqual( - set(await CharPkModel.filter(id__lte=17).values_list("id", flat=True)), - {"12", "17"}, - ) - - async def test_gt(self): - self.assertSetEqual( - set(await CharPkModel.filter(id__gt=17).values_list("id", flat=True)), {"2001"} - ) 
- - async def test_lt(self): - self.assertSetEqual( - set(await CharPkModel.filter(id__lt=17).values_list("id", flat=True)), {"12"} - ) - - async def test_sorting(self): - self.assertEqual( - await CharPkModel.all().order_by("id").values_list("id", flat=True), - ["12", "17", "2001"], - ) - self.assertEqual( - await CharPkModel.all().order_by("-id").values_list("id", flat=True), - ["2001", "17", "12"], - ) +# --- CharFields tests --- + + +@pytest_asyncio.fixture +async def char_fields_data(db): + await CharFields.create(char="moo") + await CharFields.create(char="baa", char_null="baa") + await CharFields.create(char="oink") + + +@pytest.mark.asyncio +async def test_char_field_bad_param(db, char_fields_data): + with pytest.raises(FieldError, match="Unknown filter param 'charup'. Allowed base values are"): + await CharFields.filter(charup="moo") + + +@pytest.mark.asyncio +async def test_char_field_equal(db, char_fields_data): + assert set(await CharFields.filter(char="moo").values_list("char", flat=True)) == {"moo"} + + +@pytest.mark.asyncio +async def test_char_field_enum(db, char_fields_data): + assert set(await CharFields.filter(char=MyEnum.moo).values_list("char", flat=True)) == {"moo"} + assert set(await CharFields.filter(char=MyStrEnum.moo).values_list("char", flat=True)) == { + "moo" + } + + +@pytest.mark.asyncio +async def test_char_field_not(db, char_fields_data): + assert set(await CharFields.filter(char__not="moo").values_list("char", flat=True)) == { + "baa", + "oink", + } + + +@pytest.mark.asyncio +async def test_char_field_in(db, char_fields_data): + assert set(await CharFields.filter(char__in=["moo", "baa"]).values_list("char", flat=True)) == { + "moo", + "baa", + } + + +@pytest.mark.asyncio +async def test_char_field_in_empty(db, char_fields_data): + assert await CharFields.filter(char__in=[]).values_list("char", flat=True) == [] + + +@pytest.mark.asyncio +async def test_char_field_not_in(db, char_fields_data): + assert set( + await 
CharFields.filter(char__not_in=["moo", "baa"]).values_list("char", flat=True) + ) == {"oink"} + + +@pytest.mark.asyncio +async def test_char_field_not_in_empty(db, char_fields_data): + assert set(await CharFields.filter(char__not_in=[]).values_list("char", flat=True)) == { + "oink", + "moo", + "baa", + } + + +@pytest.mark.asyncio +async def test_char_field_isnull(db, char_fields_data): + assert set(await CharFields.filter(char_null__isnull=True).values_list("char", flat=True)) == { + "moo", + "oink", + } + assert set(await CharFields.filter(char_null__isnull=False).values_list("char", flat=True)) == { + "baa" + } + + +@pytest.mark.asyncio +async def test_char_field_not_isnull(db, char_fields_data): + assert set( + await CharFields.filter(char_null__not_isnull=True).values_list("char", flat=True) + ) == {"baa"} + assert set( + await CharFields.filter(char_null__not_isnull=False).values_list("char", flat=True) + ) == {"moo", "oink"} + + +@pytest.mark.asyncio +async def test_char_field_gte(db, char_fields_data): + assert set(await CharFields.filter(char__gte="moo").values_list("char", flat=True)) == { + "moo", + "oink", + } + + +@pytest.mark.asyncio +async def test_char_field_lte(db, char_fields_data): + assert set(await CharFields.filter(char__lte="moo").values_list("char", flat=True)) == { + "moo", + "baa", + } + + +@pytest.mark.asyncio +async def test_char_field_gt(db, char_fields_data): + assert set(await CharFields.filter(char__gt="moo").values_list("char", flat=True)) == {"oink"} + + +@pytest.mark.asyncio +async def test_char_field_lt(db, char_fields_data): + assert set(await CharFields.filter(char__lt="moo").values_list("char", flat=True)) == {"baa"} + + +@pytest.mark.asyncio +async def test_char_field_contains(db, char_fields_data): + assert set(await CharFields.filter(char__contains="o").values_list("char", flat=True)) == { + "moo", + "oink", + } + + +@pytest.mark.asyncio +async def test_char_field_startswith(db, char_fields_data): + assert set(await 
CharFields.filter(char__startswith="m").values_list("char", flat=True)) == { + "moo" + } + assert ( + set(await CharFields.filter(char__startswith="s").values_list("char", flat=True)) == set() + ) + + +@pytest.mark.asyncio +async def test_char_field_endswith(db, char_fields_data): + assert set(await CharFields.filter(char__endswith="o").values_list("char", flat=True)) == { + "moo" + } + assert set(await CharFields.filter(char__endswith="s").values_list("char", flat=True)) == set() + + +@pytest.mark.asyncio +async def test_char_field_icontains(db, char_fields_data): + assert set(await CharFields.filter(char__icontains="oO").values_list("char", flat=True)) == { + "moo" + } + assert set(await CharFields.filter(char__icontains="Oo").values_list("char", flat=True)) == { + "moo" + } + + +@pytest.mark.asyncio +async def test_char_field_iexact(db, char_fields_data): + assert set(await CharFields.filter(char__iexact="MoO").values_list("char", flat=True)) == { + "moo" + } + + +@pytest.mark.asyncio +async def test_char_field_istartswith(db, char_fields_data): + assert set(await CharFields.filter(char__istartswith="m").values_list("char", flat=True)) == { + "moo" + } + assert set(await CharFields.filter(char__istartswith="M").values_list("char", flat=True)) == { + "moo" + } + + +@pytest.mark.asyncio +async def test_char_field_iendswith(db, char_fields_data): + assert set(await CharFields.filter(char__iendswith="oO").values_list("char", flat=True)) == { + "moo" + } + assert set(await CharFields.filter(char__iendswith="Oo").values_list("char", flat=True)) == { + "moo" + } + + +@pytest.mark.asyncio +async def test_char_field_sorting(db, char_fields_data): + assert await CharFields.all().order_by("char").values_list("char", flat=True) == [ + "baa", + "moo", + "oink", + ] + + +# --- BooleanFields tests --- + + +@pytest_asyncio.fixture +async def boolean_fields_data(db): + await BooleanFields.create(boolean=True) + await BooleanFields.create(boolean=False) + await 
BooleanFields.create(boolean=True, boolean_null=True) + await BooleanFields.create(boolean=False, boolean_null=True) + await BooleanFields.create(boolean=True, boolean_null=False) + await BooleanFields.create(boolean=False, boolean_null=False) + + +@pytest.mark.asyncio +async def test_boolean_field_equal_true(db, boolean_fields_data): + assert set(await BooleanFields.filter(boolean=True).values_list("boolean", "boolean_null")) == { + (True, None), + (True, True), + (True, False), + } + + +@pytest.mark.asyncio +async def test_boolean_field_equal_false(db, boolean_fields_data): + assert set( + await BooleanFields.filter(boolean=False).values_list("boolean", "boolean_null") + ) == {(False, None), (False, True), (False, False)} + + +@pytest.mark.asyncio +async def test_boolean_field_equal_true2(db, boolean_fields_data): + assert set( + await BooleanFields.filter(boolean_null=True).values_list("boolean", "boolean_null") + ) == {(False, True), (True, True)} + + +@pytest.mark.asyncio +async def test_boolean_field_equal_false2(db, boolean_fields_data): + assert set( + await BooleanFields.filter(boolean_null=False).values_list("boolean", "boolean_null") + ) == {(False, False), (True, False)} + + +@pytest.mark.asyncio +async def test_boolean_field_equal_null(db, boolean_fields_data): + assert set( + await BooleanFields.filter(boolean_null=None).values_list("boolean", "boolean_null") + ) == {(False, None), (True, None)} + + +# --- DecimalFields tests --- + + +@pytest_asyncio.fixture +async def decimal_fields_data(db): + await DecimalFields.create(decimal="1.2345", decimal_nodec=1) + await DecimalFields.create(decimal="2.34567", decimal_nodec=1) + await DecimalFields.create(decimal="2.300", decimal_nodec=1) + await DecimalFields.create(decimal="023.0", decimal_nodec=1) + await DecimalFields.create(decimal="0.230", decimal_nodec=1) + + +@pytest.mark.asyncio +async def test_decimal_field_sorting(db, decimal_fields_data): + assert await 
DecimalFields.all().order_by("decimal").values_list("decimal", flat=True) == [ + Decimal("0.23"), + Decimal("1.2345"), + Decimal("2.3"), + Decimal("2.3457"), + Decimal("23"), + ] + + +@pytest.mark.asyncio +async def test_decimal_field_gt(db, decimal_fields_data): + assert await DecimalFields.filter(decimal__gt=Decimal("1.2345")).order_by( + "decimal" + ).values_list("decimal", flat=True) == [Decimal("2.3"), Decimal("2.3457"), Decimal("23")] + + +@pytest.mark.asyncio +async def test_decimal_field_between_and(db, decimal_fields_data): + assert await DecimalFields.filter( + decimal__range=(Decimal("1.2344"), Decimal("1.2346")) + ).values_list("decimal", flat=True) == [Decimal("1.2345")] + + +@pytest.mark.asyncio +async def test_decimal_field_in(db, decimal_fields_data): + assert await DecimalFields.filter(decimal__in=[Decimal("1.2345"), Decimal("1000")]).values_list( + "decimal", flat=True + ) == [Decimal("1.2345")] + + +# --- CharPkModel / CharFkRelatedModel tests --- + + +@pytest_asyncio.fixture +async def char_fk_data(db): + model1 = await CharPkModel.create(id=17) + model2 = await CharPkModel.create(id=12) + await CharPkModel.create(id=2001) + await CharFkRelatedModel.create(model=model1) + await CharFkRelatedModel.create(model=model1) + await CharFkRelatedModel.create(model=model2) + + +@pytest.mark.asyncio +async def test_char_fk_bad_param(db, char_fk_data): + with pytest.raises( + FieldError, match="Unknown filter param 'bad_param'. 
Allowed base values are" + ): + await CharPkModel.filter(bad_param="moo") + + +@pytest.mark.asyncio +async def test_char_fk_equal(db, char_fk_data): + assert set(await CharPkModel.filter(id=2001).values_list("id", flat=True)) == {"2001"} + + +@pytest.mark.asyncio +async def test_char_fk_not(db, char_fk_data): + assert set(await CharPkModel.filter(id__not=2001).values_list("id", flat=True)) == {"17", "12"} + + +@pytest.mark.asyncio +async def test_char_fk_in(db, char_fk_data): + assert set(await CharPkModel.filter(id__in=[17, 12]).values_list("id", flat=True)) == { + "17", + "12", + } + + +@pytest.mark.asyncio +async def test_char_fk_in_empty(db, char_fk_data): + assert await CharPkModel.filter(id__in=[]).values_list("id", flat=True) == [] + + +@pytest.mark.asyncio +async def test_char_fk_not_in(db, char_fk_data): + assert set(await CharPkModel.filter(id__not_in=[17, 12]).values_list("id", flat=True)) == { + "2001" + } + + +@pytest.mark.asyncio +async def test_char_fk_not_in_empty(db, char_fk_data): + assert set(await CharPkModel.filter(id__not_in=[]).values_list("id", flat=True)) == { + "17", + "12", + "2001", + } + + +@pytest.mark.asyncio +async def test_char_fk_isnull(db, char_fk_data): + assert set(await CharPkModel.filter(children__isnull=True).values_list("id", flat=True)) == { + "2001" + } + assert await CharPkModel.filter(children__isnull=False).order_by("id").values_list( + "id", flat=True + ) == ["12", "17", "17"] + + +@pytest.mark.asyncio +async def test_char_fk_not_isnull(db, char_fk_data): + assert set( + await CharPkModel.filter(children__not_isnull=True).values_list("id", flat=True) + ) == {"17", "12"} + assert set( + await CharPkModel.filter(children__not_isnull=False).values_list("id", flat=True) + ) == {"2001"} + + +@pytest.mark.asyncio +async def test_char_fk_gte(db, char_fk_data): + assert set(await CharPkModel.filter(id__gte=17).values_list("id", flat=True)) == {"17", "2001"} + + +@pytest.mark.asyncio +async def test_char_fk_lte(db, 
char_fk_data): + assert set(await CharPkModel.filter(id__lte=17).values_list("id", flat=True)) == {"12", "17"} + + +@pytest.mark.asyncio +async def test_char_fk_gt(db, char_fk_data): + assert set(await CharPkModel.filter(id__gt=17).values_list("id", flat=True)) == {"2001"} + + +@pytest.mark.asyncio +async def test_char_fk_lt(db, char_fk_data): + assert set(await CharPkModel.filter(id__lt=17).values_list("id", flat=True)) == {"12"} + + +@pytest.mark.asyncio +async def test_char_fk_sorting(db, char_fk_data): + assert await CharPkModel.all().order_by("id").values_list("id", flat=True) == [ + "12", + "17", + "2001", + ] + assert await CharPkModel.all().order_by("-id").values_list("id", flat=True) == [ + "2001", + "17", + "12", + ] diff --git a/tests/test_fuzz.py b/tests/test_fuzz.py index d6bd10633..b8bb536f3 100644 --- a/tests/test_fuzz.py +++ b/tests/test_fuzz.py @@ -1,3 +1,5 @@ +import pytest + from tests.testmodels import CharFields from tortoise.contrib import test from tortoise.contrib.test.condition import NotEQ @@ -15,9 +17,9 @@ "''", "\\_", "\\\\_", - "‘a", - "a’", - "‘a’", + "'a", + "a'", + "'a'", "a/a", "a\\a", "0x39", @@ -101,51 +103,54 @@ ] -class TestFuzz(test.TestCase): - @test.requireCapability(dialect=NotEQ("mssql")) - async def test_char_fuzz(self): - for char in DODGY_STRINGS: - # print(repr(char)) - if "\x00" in char and self._db.capabilities.dialect in ["postgres"]: - # PostgreSQL doesn't support null values as text. Ever. So skip these. - continue +@pytest.mark.asyncio +@test.requireCapability(dialect=NotEQ("mssql")) +async def test_char_fuzz(db): + """Test character field handling with various dodgy/edge-case strings.""" + conn = db.db() + + for char in DODGY_STRINGS: + # print(repr(char)) + if "\x00" in char and conn.capabilities.dialect in ["postgres"]: + # PostgreSQL doesn't support null values as text. Ever. So skip these. 
+ continue - # Create - obj1 = await CharFields.create(char=char) + # Create + obj1 = await CharFields.create(char=char) - # Get-by-pk, and confirm that reading is correct - obj2 = await CharFields.get(pk=obj1.pk) - self.assertEqual(char, obj2.char) + # Get-by-pk, and confirm that reading is correct + obj2 = await CharFields.get(pk=obj1.pk) + assert char == obj2.char - # Update data using a queryset, confirm that update is correct - await CharFields.filter(pk=obj1.pk).update(char="a") - await CharFields.filter(pk=obj1.pk).update(char=char) - obj3 = await CharFields.get(pk=obj1.pk) - self.assertEqual(char, obj3.char) + # Update data using a queryset, confirm that update is correct + await CharFields.filter(pk=obj1.pk).update(char="a") + await CharFields.filter(pk=obj1.pk).update(char=char) + obj3 = await CharFields.get(pk=obj1.pk) + assert char == obj3.char - # Filter by value in queryset, and confirm that it fetched the right one - obj4 = await CharFields.get(pk=obj1.pk, char=char) - self.assertEqual(obj1.pk, obj4.pk) - self.assertEqual(char, obj4.char) + # Filter by value in queryset, and confirm that it fetched the right one + obj4 = await CharFields.get(pk=obj1.pk, char=char) + assert obj1.pk == obj4.pk + assert char == obj4.char - # LIKE statements are not strict, so require all of these to match - obj5 = await CharFields.get( - pk=obj1.pk, - char__startswith=char, - char__endswith=char, - char__contains=char, - char__istartswith=char, - char__iendswith=char, - char__icontains=char, - ) - self.assertEqual(obj1.pk, obj5.pk) - self.assertEqual(char, obj5.char) + # LIKE statements are not strict, so require all of these to match + obj5 = await CharFields.get( + pk=obj1.pk, + char__startswith=char, + char__endswith=char, + char__contains=char, + char__istartswith=char, + char__iendswith=char, + char__icontains=char, + ) + assert obj1.pk == obj5.pk + assert char == obj5.char - # Filter by a function - obj6 = ( - await CharFields.annotate(upper_char=Upper("char")) - 
.filter(id=obj1.pk, upper_char=Upper("char")) - .first() - ) - self.assertEqual(obj1.pk, obj6.pk) - self.assertEqual(char, obj6.char) + # Filter by a function + obj6 = ( + await CharFields.annotate(upper_char=Upper("char")) + .filter(id=obj1.pk, upper_char=Upper("char")) + .first() + ) + assert obj1.pk == obj6.pk + assert char == obj6.char diff --git a/tests/test_group_by.py b/tests/test_group_by.py index 14a7c7cb6..efb9ad758 100644 --- a/tests/test_group_by.py +++ b/tests/test_group_by.py @@ -1,331 +1,369 @@ +import pytest +import pytest_asyncio + from tests.testmodels import Author, Book, Event, Team, Tournament -from tortoise.contrib import test from tortoise.expressions import Subquery from tortoise.functions import Avg, Count, Sum, Upper -class TestGroupBy(test.TestCase): - async def asyncSetUp(self) -> None: - await super().asyncSetUp() - self.a1 = await Author.create(name="author1") - self.a2 = await Author.create(name="author2") - self.books1 = [ - await Book.create(name=f"book{i}", author=self.a1, rating=i) for i in range(10) - ] - self.books2 = [ - await Book.create(name=f"book{i}", author=self.a2, rating=i) for i in range(5) - ] - - async def test_count_group_by(self): - ret = ( - await Book.annotate(count=Count("id")) - .group_by("author_id") - .values("author_id", "count") - ) - - for item in ret: - author_id = item.get("author_id") - count = item.get("count") - if author_id == self.a1.pk: - self.assertEqual(count, 10) - elif author_id == self.a2.pk: - self.assertEqual(count, 5) - - async def test_count_group_by_with_join(self): - ret = ( - await Book.annotate(count=Count("id")) - .group_by("author__name") - .values("author__name", "count") - ) - self.assertListSortEqual( - ret, - [{"author__name": "author1", "count": 10}, {"author__name": "author2", "count": 5}], - sorted_key="author__name", - ) - - async def test_count_filter_group_by(self): - ret = ( - await Book.annotate(count=Count("id")) - .filter(count__gt=6) - .group_by("author_id") - 
.values("author_id", "count") - ) - self.assertEqual(len(ret), 1) - self.assertEqual(ret[0].get("count"), 10) - - async def test_sum_group_by(self): - ret = ( - await Book.annotate(sum=Sum("rating")).group_by("author_id").values("author_id", "sum") - ) - for item in ret: - author_id = item.get("author_id") - sum_ = item.get("sum") - if author_id == self.a1.pk: - self.assertEqual(sum_, 45.0) - elif author_id == self.a2.pk: - self.assertEqual(sum_, 10.0) - - async def test_sum_group_by_with_join(self): - ret = ( - await Book.annotate(sum=Sum("rating")) - .group_by("author__name") - .values("author__name", "sum") - ) - self.assertListSortEqual( - ret, - [{"author__name": "author1", "sum": 45.0}, {"author__name": "author2", "sum": 10.0}], - sorted_key="author__name", - ) - - async def test_sum_filter_group_by(self): - ret = ( - await Book.annotate(sum=Sum("rating")) - .filter(sum__gt=11) - .group_by("author_id") - .values("author_id", "sum") - ) - self.assertEqual(len(ret), 1) - self.assertEqual(ret[0].get("sum"), 45.0) - - async def test_avg_group_by(self): - ret = ( - await Book.annotate(avg=Avg("rating")).group_by("author_id").values("author_id", "avg") - ) - - for item in ret: - author_id = item.get("author_id") - avg = item.get("avg") - if author_id == self.a1.pk: - self.assertEqual(avg, 4.5) - elif author_id == self.a2.pk: - self.assertEqual(avg, 2.0) - - async def test_avg_group_by_with_join(self): - ret = ( - await Book.annotate(avg=Avg("rating")) - .group_by("author__name") - .values("author__name", "avg") - ) - self.assertListSortEqual( - ret, - [{"author__name": "author1", "avg": 4.5}, {"author__name": "author2", "avg": 2}], - sorted_key="author__name", - ) - - async def test_avg_filter_group_by(self): - ret = ( - await Book.annotate(avg=Avg("rating")) - .filter(avg__gt=3) - .group_by("author_id") - .values_list("author_id", "avg") - ) - self.assertEqual(len(ret), 1) - self.assertEqual(ret[0][1], 4.5) - - async def test_count_values_list_group_by(self): - 
ret = ( - await Book.annotate(count=Count("id")) - .group_by("author_id") - .values_list("author_id", "count") - ) - - for item in ret: - author_id = item[0] - count = item[1] - if author_id == self.a1.pk: - self.assertEqual(count, 10) - elif author_id == self.a2.pk: - self.assertEqual(count, 5) - - async def test_count_values_list_group_by_with_join(self): - ret = ( - await Book.annotate(count=Count("id")) - .group_by("author__name") - .values_list("author__name", "count") - ) - self.assertListSortEqual(ret, [("author1", 10), ("author2", 5)]) - - async def test_count_values_list_filter_group_by(self): - ret = ( - await Book.annotate(count=Count("id")) - .filter(count__gt=6) - .group_by("author_id") - .values_list("author_id", "count") - ) - self.assertEqual(len(ret), 1) - self.assertEqual(ret[0][1], 10) - - async def test_sum_values_list_group_by(self): - ret = ( - await Book.annotate(sum=Sum("rating")) - .group_by("author_id") - .values_list("author_id", "sum") - ) - for item in ret: - author_id = item[0] - sum_ = item[1] - if author_id == self.a1.pk: - self.assertEqual(sum_, 45.0) - elif author_id == self.a2.pk: - self.assertEqual(sum_, 10.0) - - async def test_sum_values_list_group_by_with_join(self): - ret = ( - await Book.annotate(sum=Sum("rating")) - .group_by("author__name") - .values_list("author__name", "sum") - ) - self.assertListSortEqual(ret, [("author1", 45.0), ("author2", 10.0)]) - - async def test_sum_values_list_filter_group_by(self): - ret = ( - await Book.annotate(sum=Sum("rating")) - .filter(sum__gt=11) - .group_by("author_id") - .values_list("author_id", "sum") - ) - self.assertEqual(len(ret), 1) - self.assertEqual(ret[0][1], 45.0) - - async def test_avg_values_list_group_by(self): - ret = ( - await Book.annotate(avg=Avg("rating")) - .group_by("author_id") - .values_list("author_id", "avg") - ) - - for item in ret: - author_id = item[0] - avg = item[1] - if author_id == self.a1.pk: - self.assertEqual(avg, 4.5) - elif author_id == self.a2.pk: - 
self.assertEqual(avg, 2.0) - - async def test_avg_values_list_group_by_with_join(self): - ret = ( - await Book.annotate(avg=Avg("rating")) - .group_by("author__name") - .values_list("author__name", "avg") - ) - self.assertListSortEqual(ret, [("author1", 4.5), ("author2", 2.0)]) - - async def test_avg_values_list_filter_group_by(self): - ret = ( - await Book.annotate(avg=Avg("rating")) - .filter(avg__gt=3) - .group_by("author_id") - .values_list("author_id", "avg") - ) - self.assertEqual(len(ret), 1) - self.assertEqual(ret[0][1], 4.5) - - async def test_implicit_group_by(self): - ret = await Author.annotate(count=Count("books")).filter(count__gt=6) - self.assertEqual(ret[0].count, 10) - - async def test_group_by_annotate_result(self): - ret = ( - await Book.annotate(upper_name=Upper("author__name"), count=Count("id")) - .group_by("upper_name") - .values("upper_name", "count") - ) - self.assertListSortEqual( - ret, - [{"upper_name": "AUTHOR1", "count": 10}, {"upper_name": "AUTHOR2", "count": 5}], - sorted_key="upper_name", - ) - - async def test_group_by_requiring_nested_joins(self): - tournament_first = await Tournament.create(name="Tournament 1", desc="d1") - tournament_second = await Tournament.create(name="Tournament 2", desc="d2") - - event_first = await Event.create(name="1", tournament=tournament_first) - event_second = await Event.create(name="2", tournament=tournament_first) - event_third = await Event.create(name="3", tournament=tournament_second) - - team_first = await Team.create(name="First", alias=2) - team_second = await Team.create(name="Second", alias=4) - team_third = await Team.create(name="Third", alias=5) - - await team_first.events.add(event_first) - await team_second.events.add(event_second) - await team_third.events.add(event_third) - - ret = ( - await Tournament.annotate(avg=Avg("events__participants__alias")) - .group_by("desc") - .order_by("desc") - .values("desc", "avg") - ) - self.assertEqual(ret, [{"avg": 3, "desc": "d1"}, {"avg": 5, 
"desc": "d2"}]) - - async def test_group_by_ambigious_column(self): - tournament_first = await Tournament.create(name="Tournament 1") - tournament_second = await Tournament.create(name="Tournament 2") - - await Event.create(name="1", tournament=tournament_first) - await Event.create(name="2", tournament=tournament_first) - await Event.create(name="3", tournament=tournament_second) - - base_query = ( - Tournament.annotate(event_count=Count("events")).group_by("name").order_by("name") - ) - ret = await base_query.values("name", "event_count") - self.assertEqual( - ret, - [ - {"event_count": 2, "name": "Tournament 1"}, - {"event_count": 1, "name": "Tournament 2"}, - ], - ) - - ret = await base_query.values_list("name", "event_count") - self.assertEqual( - ret, - [("Tournament 1", 2), ("Tournament 2", 1)], - ) - - async def test_group_by_nested_column(self): - tournament_first = await Tournament.create(name="A") - tournament_second = await Tournament.create(name="B") - - await Event.create(name="1", tournament=tournament_first) - await Event.create(name="2", tournament=tournament_first) - await Event.create(name="3", tournament=tournament_first) - await Event.create(name="4", tournament=tournament_second) - - base_query = ( - Event.annotate(count=Count("event_id")) - .group_by("tournament__name") - .order_by("-tournament__name") - ) - ret = await base_query.values("tournament__name", "count") - self.assertEqual( - ret, - [ - {"count": 1, "tournament__name": "B"}, - {"count": 3, "tournament__name": "A"}, - ], - ) - - ret = await base_query.values_list("tournament__name", "count") - self.assertEqual( - ret, - [("B", 1), ("A", 3)], - ) - - async def test_group_by_id_with_nested_filter(self): - ret = await Book.filter(author__name="author1").group_by("id").values_list("id") - self.assertEqual(set(ret), {(book.id,) for book in self.books1}) - - async def test_select_subquery_with_group_by(self): - subquery = Subquery( - 
Book.all().group_by("rating").order_by("-rating").limit(1).values("rating") - ) - ret = ( - await Author.annotate(top_rating=subquery) - .order_by("id") - .values_list("name", "top_rating") - ) - self.assertEqual(ret, [(self.a1.name, 9.0), (self.a2.name, 9.0)]) +@pytest_asyncio.fixture +async def group_by_data(db): + """Set up Author and Book data for group_by tests.""" + a1 = await Author.create(name="author1") + a2 = await Author.create(name="author2") + books1 = [await Book.create(name=f"book{i}", author=a1, rating=i) for i in range(10)] + books2 = [await Book.create(name=f"book{i}", author=a2, rating=i) for i in range(5)] + return {"a1": a1, "a2": a2, "books1": books1, "books2": books2} + + +@pytest.mark.asyncio +async def test_count_group_by(db, group_by_data): + a1 = group_by_data["a1"] + a2 = group_by_data["a2"] + + ret = await Book.annotate(count=Count("id")).group_by("author_id").values("author_id", "count") + + for item in ret: + author_id = item.get("author_id") + count = item.get("count") + if author_id == a1.pk: + assert count == 10 + elif author_id == a2.pk: + assert count == 5 + + +@pytest.mark.asyncio +async def test_count_group_by_with_join(db, group_by_data): + ret = ( + await Book.annotate(count=Count("id")) + .group_by("author__name") + .values("author__name", "count") + ) + assert sorted(ret, key=lambda x: x["author__name"]) == sorted( + [{"author__name": "author1", "count": 10}, {"author__name": "author2", "count": 5}], + key=lambda x: x["author__name"], + ) + + +@pytest.mark.asyncio +async def test_count_filter_group_by(db, group_by_data): + ret = ( + await Book.annotate(count=Count("id")) + .filter(count__gt=6) + .group_by("author_id") + .values("author_id", "count") + ) + assert len(ret) == 1 + assert ret[0].get("count") == 10 + + +@pytest.mark.asyncio +async def test_sum_group_by(db, group_by_data): + a1 = group_by_data["a1"] + a2 = group_by_data["a2"] + + ret = await 
Book.annotate(sum=Sum("rating")).group_by("author_id").values("author_id", "sum") + for item in ret: + author_id = item.get("author_id") + sum_ = item.get("sum") + if author_id == a1.pk: + assert sum_ == 45.0 + elif author_id == a2.pk: + assert sum_ == 10.0 + + +@pytest.mark.asyncio +async def test_sum_group_by_with_join(db, group_by_data): + ret = ( + await Book.annotate(sum=Sum("rating")) + .group_by("author__name") + .values("author__name", "sum") + ) + assert sorted(ret, key=lambda x: x["author__name"]) == sorted( + [{"author__name": "author1", "sum": 45.0}, {"author__name": "author2", "sum": 10.0}], + key=lambda x: x["author__name"], + ) + + +@pytest.mark.asyncio +async def test_sum_filter_group_by(db, group_by_data): + ret = ( + await Book.annotate(sum=Sum("rating")) + .filter(sum__gt=11) + .group_by("author_id") + .values("author_id", "sum") + ) + assert len(ret) == 1 + assert ret[0].get("sum") == 45.0 + + +@pytest.mark.asyncio +async def test_avg_group_by(db, group_by_data): + a1 = group_by_data["a1"] + a2 = group_by_data["a2"] + + ret = await Book.annotate(avg=Avg("rating")).group_by("author_id").values("author_id", "avg") + + for item in ret: + author_id = item.get("author_id") + avg = item.get("avg") + if author_id == a1.pk: + assert avg == 4.5 + elif author_id == a2.pk: + assert avg == 2.0 + + +@pytest.mark.asyncio +async def test_avg_group_by_with_join(db, group_by_data): + ret = ( + await Book.annotate(avg=Avg("rating")) + .group_by("author__name") + .values("author__name", "avg") + ) + assert sorted(ret, key=lambda x: x["author__name"]) == sorted( + [{"author__name": "author1", "avg": 4.5}, {"author__name": "author2", "avg": 2}], + key=lambda x: x["author__name"], + ) + + +@pytest.mark.asyncio +async def test_avg_filter_group_by(db, group_by_data): + ret = ( + await Book.annotate(avg=Avg("rating")) + .filter(avg__gt=3) + .group_by("author_id") + .values_list("author_id", "avg") + ) + assert len(ret) == 1 + assert ret[0][1] == 4.5 + + 
+@pytest.mark.asyncio +async def test_count_values_list_group_by(db, group_by_data): + a1 = group_by_data["a1"] + a2 = group_by_data["a2"] + + ret = ( + await Book.annotate(count=Count("id")) + .group_by("author_id") + .values_list("author_id", "count") + ) + + for item in ret: + author_id = item[0] + count = item[1] + if author_id == a1.pk: + assert count == 10 + elif author_id == a2.pk: + assert count == 5 + + +@pytest.mark.asyncio +async def test_count_values_list_group_by_with_join(db, group_by_data): + ret = ( + await Book.annotate(count=Count("id")) + .group_by("author__name") + .values_list("author__name", "count") + ) + assert sorted(ret) == sorted([("author1", 10), ("author2", 5)]) + + +@pytest.mark.asyncio +async def test_count_values_list_filter_group_by(db, group_by_data): + ret = ( + await Book.annotate(count=Count("id")) + .filter(count__gt=6) + .group_by("author_id") + .values_list("author_id", "count") + ) + assert len(ret) == 1 + assert ret[0][1] == 10 + + +@pytest.mark.asyncio +async def test_sum_values_list_group_by(db, group_by_data): + a1 = group_by_data["a1"] + a2 = group_by_data["a2"] + + ret = ( + await Book.annotate(sum=Sum("rating")).group_by("author_id").values_list("author_id", "sum") + ) + for item in ret: + author_id = item[0] + sum_ = item[1] + if author_id == a1.pk: + assert sum_ == 45.0 + elif author_id == a2.pk: + assert sum_ == 10.0 + + +@pytest.mark.asyncio +async def test_sum_values_list_group_by_with_join(db, group_by_data): + ret = ( + await Book.annotate(sum=Sum("rating")) + .group_by("author__name") + .values_list("author__name", "sum") + ) + assert sorted(ret) == sorted([("author1", 45.0), ("author2", 10.0)]) + + +@pytest.mark.asyncio +async def test_sum_values_list_filter_group_by(db, group_by_data): + ret = ( + await Book.annotate(sum=Sum("rating")) + .filter(sum__gt=11) + .group_by("author_id") + .values_list("author_id", "sum") + ) + assert len(ret) == 1 + assert ret[0][1] == 45.0 + + +@pytest.mark.asyncio +async def 
test_avg_values_list_group_by(db, group_by_data): + a1 = group_by_data["a1"] + a2 = group_by_data["a2"] + + ret = ( + await Book.annotate(avg=Avg("rating")).group_by("author_id").values_list("author_id", "avg") + ) + + for item in ret: + author_id = item[0] + avg = item[1] + if author_id == a1.pk: + assert avg == 4.5 + elif author_id == a2.pk: + assert avg == 2.0 + + +@pytest.mark.asyncio +async def test_avg_values_list_group_by_with_join(db, group_by_data): + ret = ( + await Book.annotate(avg=Avg("rating")) + .group_by("author__name") + .values_list("author__name", "avg") + ) + assert sorted(ret) == sorted([("author1", 4.5), ("author2", 2.0)]) + + +@pytest.mark.asyncio +async def test_avg_values_list_filter_group_by(db, group_by_data): + ret = ( + await Book.annotate(avg=Avg("rating")) + .filter(avg__gt=3) + .group_by("author_id") + .values_list("author_id", "avg") + ) + assert len(ret) == 1 + assert ret[0][1] == 4.5 + + +@pytest.mark.asyncio +async def test_implicit_group_by(db, group_by_data): + ret = await Author.annotate(count=Count("books")).filter(count__gt=6) + assert ret[0].count == 10 + + +@pytest.mark.asyncio +async def test_group_by_annotate_result(db, group_by_data): + ret = ( + await Book.annotate(upper_name=Upper("author__name"), count=Count("id")) + .group_by("upper_name") + .values("upper_name", "count") + ) + assert sorted(ret, key=lambda x: x["upper_name"]) == sorted( + [{"upper_name": "AUTHOR1", "count": 10}, {"upper_name": "AUTHOR2", "count": 5}], + key=lambda x: x["upper_name"], + ) + + +@pytest.mark.asyncio +async def test_group_by_requiring_nested_joins(db): + tournament_first = await Tournament.create(name="Tournament 1", desc="d1") + tournament_second = await Tournament.create(name="Tournament 2", desc="d2") + + event_first = await Event.create(name="1", tournament=tournament_first) + event_second = await Event.create(name="2", tournament=tournament_first) + event_third = await Event.create(name="3", tournament=tournament_second) + + 
team_first = await Team.create(name="First", alias=2) + team_second = await Team.create(name="Second", alias=4) + team_third = await Team.create(name="Third", alias=5) + + await team_first.events.add(event_first) + await team_second.events.add(event_second) + await team_third.events.add(event_third) + + ret = ( + await Tournament.annotate(avg=Avg("events__participants__alias")) + .group_by("desc") + .order_by("desc") + .values("desc", "avg") + ) + assert ret == [{"avg": 3, "desc": "d1"}, {"avg": 5, "desc": "d2"}] + + +@pytest.mark.asyncio +async def test_group_by_ambigious_column(db): + tournament_first = await Tournament.create(name="Tournament 1") + tournament_second = await Tournament.create(name="Tournament 2") + + await Event.create(name="1", tournament=tournament_first) + await Event.create(name="2", tournament=tournament_first) + await Event.create(name="3", tournament=tournament_second) + + base_query = Tournament.annotate(event_count=Count("events")).group_by("name").order_by("name") + ret = await base_query.values("name", "event_count") + assert ret == [ + {"event_count": 2, "name": "Tournament 1"}, + {"event_count": 1, "name": "Tournament 2"}, + ] + + ret = await base_query.values_list("name", "event_count") + assert ret == [("Tournament 1", 2), ("Tournament 2", 1)] + + +@pytest.mark.asyncio +async def test_group_by_nested_column(db): + tournament_first = await Tournament.create(name="A") + tournament_second = await Tournament.create(name="B") + + await Event.create(name="1", tournament=tournament_first) + await Event.create(name="2", tournament=tournament_first) + await Event.create(name="3", tournament=tournament_first) + await Event.create(name="4", tournament=tournament_second) + + base_query = ( + Event.annotate(count=Count("event_id")) + .group_by("tournament__name") + .order_by("-tournament__name") + ) + ret = await base_query.values("tournament__name", "count") + assert ret == [ + {"count": 1, "tournament__name": "B"}, + {"count": 3, 
"tournament__name": "A"}, + ] + + ret = await base_query.values_list("tournament__name", "count") + assert ret == [("B", 1), ("A", 3)] + + +@pytest.mark.asyncio +async def test_group_by_id_with_nested_filter(db, group_by_data): + books1 = group_by_data["books1"] + + ret = await Book.filter(author__name="author1").group_by("id").values_list("id") + assert set(ret) == {(book.id,) for book in books1} + + +@pytest.mark.asyncio +async def test_select_subquery_with_group_by(db, group_by_data): + a1 = group_by_data["a1"] + a2 = group_by_data["a2"] + + subquery = Subquery(Book.all().group_by("rating").order_by("-rating").limit(1).values("rating")) + ret = ( + await Author.annotate(top_rating=subquery).order_by("id").values_list("name", "top_rating") + ) + assert ret == [(a1.name, 9.0), (a2.name, 9.0)] diff --git a/tests/test_inheritance.py b/tests/test_inheritance.py index 905e9c1aa..5224710d6 100644 --- a/tests/test_inheritance.py +++ b/tests/test_inheritance.py @@ -1,15 +1,17 @@ +import pytest + from tests.testmodels import MyAbstractBaseModel, MyDerivedModel -from tortoise.contrib import test -class TestInheritance(test.TestCase): - async def test_basic(self): - model = MyDerivedModel(name="test") - self.assertTrue(hasattr(MyAbstractBaseModel(), "name")) - self.assertTrue(hasattr(model, "created_at")) - self.assertTrue(hasattr(model, "modified_at")) - self.assertTrue(hasattr(model, "name")) - self.assertTrue(hasattr(model, "first_name")) - await model.save() - self.assertIsNotNone(model.created_at) - self.assertIsNotNone(model.modified_at) +@pytest.mark.asyncio +async def test_basic(db): + """Test basic model inheritance with abstract base model.""" + model = MyDerivedModel(name="test") + assert hasattr(MyAbstractBaseModel(), "name") + assert hasattr(model, "created_at") + assert hasattr(model, "modified_at") + assert hasattr(model, "name") + assert hasattr(model, "first_name") + await model.save() + assert model.created_at is not None + assert model.modified_at is not 
None diff --git a/tests/test_latest_earliest.py b/tests/test_latest_earliest.py index 0c65d38d5..eb2420024 100644 --- a/tests/test_latest_earliest.py +++ b/tests/test_latest_earliest.py @@ -1,51 +1,47 @@ +import pytest +import pytest_asyncio + from tests.testmodels import Event, Tournament -from tortoise.contrib import test - - -class TestLatestEarliest(test.TestCase): - async def asyncSetUp(self): - await super().asyncSetUp() - tournament = Tournament(name="Tournament 1") - await tournament.save() - - second_tournament = Tournament(name="Tournament 2") - await second_tournament.save() - - event_first = Event(name="1", tournament=tournament) - await event_first.save() - event_second = Event(name="2", tournament=second_tournament) - await event_second.save() - event_third = Event(name="3", tournament=tournament) - await event_third.save() - event_forth = Event(name="4", tournament=second_tournament) - await event_forth.save() - - async def test_latest(self): - self.assertEqual(await Event.latest("-name"), await Event.get(name="1")) - self.assertEqual(await Event.latest("name"), await Event.get(name="4")) - self.assertEqual(await Event.latest("-name"), await Event.all().order_by("name").first()) - self.assertEqual(await Event.latest("name"), await Event.all().order_by("-name").first()) - self.assertEqual(await Event.latest("tournament__name", "name"), await Event.get(name="4")) - self.assertEqual(await Event.latest("-tournament__name", "name"), await Event.get(name="3")) - self.assertEqual(await Event.latest("tournament__name", "-name"), await Event.get(name="2")) - self.assertEqual( - await Event.latest("-tournament__name", "-name"), await Event.get(name="1") - ) - - async def test_earliest(self): - self.assertEqual(await Event.earliest("name"), await Event.get(name="1")) - self.assertEqual(await Event.earliest("-name"), await Event.get(name="4")) - self.assertEqual(await Event.earliest("name"), await Event.all().order_by("name").first()) - self.assertEqual(await 
Event.earliest("-name"), await Event.all().order_by("-name").first()) - self.assertEqual( - await Event.earliest("-tournament__name", "-name"), await Event.get(name="4") - ) - self.assertEqual( - await Event.earliest("tournament__name", "-name"), await Event.get(name="3") - ) - self.assertEqual( - await Event.earliest("-tournament__name", "name"), await Event.get(name="2") - ) - self.assertEqual( - await Event.earliest("tournament__name", "name"), await Event.get(name="1") - ) + + +@pytest_asyncio.fixture +async def latest_earliest_data(db): + """Fixture to set up test data for latest/earliest tests.""" + tournament = Tournament(name="Tournament 1") + await tournament.save() + + second_tournament = Tournament(name="Tournament 2") + await second_tournament.save() + + event_first = Event(name="1", tournament=tournament) + await event_first.save() + event_second = Event(name="2", tournament=second_tournament) + await event_second.save() + event_third = Event(name="3", tournament=tournament) + await event_third.save() + event_forth = Event(name="4", tournament=second_tournament) + await event_forth.save() + + +@pytest.mark.asyncio +async def test_latest(latest_earliest_data): + assert await Event.latest("-name") == await Event.get(name="1") + assert await Event.latest("name") == await Event.get(name="4") + assert await Event.latest("-name") == await Event.all().order_by("name").first() + assert await Event.latest("name") == await Event.all().order_by("-name").first() + assert await Event.latest("tournament__name", "name") == await Event.get(name="4") + assert await Event.latest("-tournament__name", "name") == await Event.get(name="3") + assert await Event.latest("tournament__name", "-name") == await Event.get(name="2") + assert await Event.latest("-tournament__name", "-name") == await Event.get(name="1") + + +@pytest.mark.asyncio +async def test_earliest(latest_earliest_data): + assert await Event.earliest("name") == await Event.get(name="1") + assert await 
Event.earliest("-name") == await Event.get(name="4") + assert await Event.earliest("name") == await Event.all().order_by("name").first() + assert await Event.earliest("-name") == await Event.all().order_by("-name").first() + assert await Event.earliest("-tournament__name", "-name") == await Event.get(name="4") + assert await Event.earliest("tournament__name", "-name") == await Event.get(name="3") + assert await Event.earliest("-tournament__name", "name") == await Event.get(name="2") + assert await Event.earliest("tournament__name", "name") == await Event.get(name="1") diff --git a/tests/test_manager.py b/tests/test_manager.py index c5a5cd900..be817de68 100644 --- a/tests/test_manager.py +++ b/tests/test_manager.py @@ -1,19 +1,21 @@ +import pytest + from tests.testmodels import ManagerModel, ManagerModelExtra -from tortoise.contrib import test -class TestManager(test.TestCase): - async def test_manager(self): - m1 = await ManagerModel.create() - m2 = await ManagerModel.create(status=1) +@pytest.mark.asyncio +async def test_manager(db): + """Test custom manager functionality with active status filtering.""" + m1 = await ManagerModel.create() + m2 = await ManagerModel.create(status=1) - self.assertEqual(await ManagerModel.all().active().count(), 1) - self.assertEqual(await ManagerModel.all_objects.count(), 2) + assert await ManagerModel.all().active().count() == 1 + assert await ManagerModel.all_objects.count() == 2 - self.assertIsNone(await ManagerModel.all().active().get_or_none(pk=m1.pk)) - self.assertIsNotNone(await ManagerModel.all_objects.get_or_none(pk=m1.pk)) - self.assertIsNotNone(await ManagerModel.get_or_none(pk=m2.pk)) + assert await ManagerModel.all().active().get_or_none(pk=m1.pk) is None + assert await ManagerModel.all_objects.get_or_none(pk=m1.pk) is not None + assert await ManagerModel.get_or_none(pk=m2.pk) is not None - await ManagerModelExtra.create(extra="extra") - self.assertEqual(await ManagerModelExtra.all_objects.count(), 1) - 
self.assertEqual(await ManagerModelExtra.all().count(), 1) + await ManagerModelExtra.create(extra="extra") + assert await ManagerModelExtra.all_objects.count() == 1 + assert await ManagerModelExtra.all().count() == 1 diff --git a/tests/test_manual_sql.py b/tests/test_manual_sql.py index bfbf5cc95..2e1d22218 100644 --- a/tests/test_manual_sql.py +++ b/tests/test_manual_sql.py @@ -1,56 +1,65 @@ +import pytest + from tortoise import connections -from tortoise.contrib import test +from tortoise.contrib.test import requireCapability from tortoise.transactions import in_transaction -class TestManualSQL(test.TruncationTestCase): - async def test_simple_insert(self): - conn = connections.get("models") +@pytest.mark.asyncio +async def test_simple_insert(db_truncate): + """Test simple INSERT via raw SQL.""" + conn = connections.get("models") + await conn.execute_query("INSERT INTO author (name) VALUES ('Foo')") + assert await conn.execute_query_dict("SELECT name FROM author") == [{"name": "Foo"}] + + +@pytest.mark.asyncio +async def test_in_transaction(db_truncate): + """Test INSERT inside transaction context manager.""" + async with in_transaction() as conn: await conn.execute_query("INSERT INTO author (name) VALUES ('Foo')") - self.assertEqual( - await conn.execute_query_dict("SELECT name FROM author"), [{"name": "Foo"}] - ) - async def test_in_transaction(self): + conn = connections.get("models") + assert await conn.execute_query_dict("SELECT name FROM author") == [{"name": "Foo"}] + + +@requireCapability(supports_transactions=True) +@pytest.mark.asyncio +async def test_in_transaction_exception(db_truncate): + """Test that transaction rolls back on exception.""" + try: async with in_transaction() as conn: await conn.execute_query("INSERT INTO author (name) VALUES ('Foo')") + raise ValueError("oops") + except ValueError: + pass + + conn = connections.get("models") + assert await conn.execute_query_dict("SELECT name FROM author") == [] - conn = connections.get("models") - 
self.assertEqual( - await conn.execute_query_dict("SELECT name FROM author"), [{"name": "Foo"}] - ) - - @test.requireCapability(supports_transactions=True) - async def test_in_transaction_exception(self): - try: - async with in_transaction() as conn: - await conn.execute_query("INSERT INTO author (name) VALUES ('Foo')") - raise ValueError("oops") - except ValueError: - pass - - conn = connections.get("models") - self.assertEqual(await conn.execute_query_dict("SELECT name FROM author"), []) - - @test.requireCapability(supports_transactions=True) - async def test_in_transaction_rollback(self): + +@requireCapability(supports_transactions=True) +@pytest.mark.asyncio +async def test_in_transaction_rollback(db_truncate): + """Test explicit rollback inside transaction.""" + async with in_transaction() as conn: + await conn.execute_query("INSERT INTO author (name) VALUES ('Foo')") + await conn.rollback() + + conn = connections.get("models") + assert await conn.execute_query_dict("SELECT name FROM author") == [] + + +@pytest.mark.asyncio +async def test_in_transaction_commit(db_truncate): + """Test explicit commit inside transaction persists data even on exception.""" + try: async with in_transaction() as conn: await conn.execute_query("INSERT INTO author (name) VALUES ('Foo')") - await conn.rollback() - - conn = connections.get("models") - self.assertEqual(await conn.execute_query_dict("SELECT name FROM author"), []) - - async def test_in_transaction_commit(self): - try: - async with in_transaction() as conn: - await conn.execute_query("INSERT INTO author (name) VALUES ('Foo')") - await conn.commit() - raise ValueError("oops") - except ValueError: - pass - - conn = connections.get("models") - self.assertEqual( - await conn.execute_query_dict("SELECT name FROM author"), [{"name": "Foo"}] - ) + await conn.commit() + raise ValueError("oops") + except ValueError: + pass + + conn = connections.get("models") + assert await conn.execute_query_dict("SELECT name FROM author") == 
[{"name": "Foo"}] diff --git a/tests/test_model_field_name_conflicts.py b/tests/test_model_field_name_conflicts.py index 58d612fd8..6ed916076 100644 --- a/tests/test_model_field_name_conflicts.py +++ b/tests/test_model_field_name_conflicts.py @@ -1,18 +1,18 @@ -import unittest +import pytest from tortoise import fields from tortoise.exceptions import ConfigurationError from tortoise.models import Model -class TestModelFieldNameConflicts(unittest.TestCase): - def test_field_name_conflicts_with_model_attributes(self) -> None: - with self.assertRaises(ConfigurationError) as ctx: +def test_field_name_conflicts_with_model_attributes(): + """Test that using reserved model attribute names as field names raises ConfigurationError.""" + with pytest.raises(ConfigurationError) as exc_info: - class BadModel(Model): - save = fields.IntField() # type: ignore[assignment] - get_table = fields.IntField() # type: ignore[assignment] + class BadModel(Model): + save = fields.IntField() + get_table = fields.IntField() - message = str(ctx.exception) - self.assertIn("save", message) - self.assertIn("get_table", message) + message = str(exc_info.value) + assert "save" in message + assert "get_table" in message diff --git a/tests/test_model_get_table.py b/tests/test_model_get_table.py index ba1019341..a0275bd6a 100644 --- a/tests/test_model_get_table.py +++ b/tests/test_model_get_table.py @@ -1,9 +1,10 @@ from typing import Any, cast +import pytest_asyncio from pypika_tortoise import Table -from tortoise import Tortoise, fields -from tortoise.contrib.test import SimpleTestCase +from tortoise import fields +from tortoise.context import tortoise_test_context from tortoise.models import Model @@ -23,29 +24,27 @@ def _get_table(model: type[Model]) -> Table: return cast(Any, model).get_table() -class TestModelGetTable(SimpleTestCase): - async def asyncSetUp(self) -> None: - await super().asyncSetUp() - await Tortoise.init(db_url="sqlite://:memory:", modules={"models": [__name__]}) - await 
Tortoise.generate_schemas() +@pytest_asyncio.fixture +async def model_get_table_db(): + """Fixture for model get_table tests with in-memory SQLite.""" + async with tortoise_test_context(modules=[__name__]) as ctx: + yield ctx - async def _tearDownDB(self) -> None: - await Tortoise.get_connection("default").close() - def test_get_table_returns_fresh_table(self) -> None: - table = _get_table(SchemaModel) +def test_get_table_returns_fresh_table(model_get_table_db): + table = _get_table(SchemaModel) - self.assertIsInstance(table, Table) - self.assertEqual(table.get_table_name(), SchemaModel._meta.db_table) - self.assertIsNotNone(table._schema) - assert table._schema is not None - self.assertEqual(table._schema._name, SchemaModel._meta.schema) - self.assertIsNot(table, SchemaModel._meta.basetable) - self.assertIsNot(table, _get_table(SchemaModel)) + assert isinstance(table, Table) + assert table.get_table_name() == SchemaModel._meta.db_table + assert table._schema is not None + assert table._schema._name == SchemaModel._meta.schema + assert table is not SchemaModel._meta.basetable + assert table is not _get_table(SchemaModel) - def test_get_table_default_schema(self) -> None: - table = _get_table(DefaultSchemaModel) - self.assertIsInstance(table, Table) - self.assertEqual(table.get_table_name(), DefaultSchemaModel._meta.db_table) - self.assertIsNone(table._schema) +def test_get_table_default_schema(model_get_table_db): + table = _get_table(DefaultSchemaModel) + + assert isinstance(table, Table) + assert table.get_table_name() == DefaultSchemaModel._meta.db_table + assert table._schema is None diff --git a/tests/test_model_methods.py b/tests/test_model_methods.py index f769593d0..7887af55f 100644 --- a/tests/test_model_methods.py +++ b/tests/test_model_methods.py @@ -1,6 +1,9 @@ import os from uuid import uuid4 +import pytest +import pytest_asyncio + from tests.testmodels import ( Dest_null, Event, @@ -13,7 +16,7 @@ Tournament, UUIDFkRelatedNullModel, ) -from 
tortoise.contrib import test +from tortoise.contrib.test import requireCapability from tortoise.contrib.test.condition import NotEQ from tortoise.exceptions import ( ConfigurationError, @@ -28,413 +31,836 @@ from tortoise.expressions import F, Q from tortoise.models import NoneAwaitable +# ============================================================================ +# TestModelCreate +# ============================================================================ + + +@pytest.mark.asyncio +async def test_save_generated(db): + mdl = await Tournament.create(name="Test") + mdl2 = await Tournament.get(id=mdl.id) + assert mdl == mdl2 + + +@pytest.mark.asyncio +async def test_save_non_generated(db): + mdl = await UUIDFkRelatedNullModel.create(name="Test") + mdl2 = await UUIDFkRelatedNullModel.get(id=mdl.id) + assert mdl == mdl2 + + +@requireCapability(dialect=NotEQ("mssql")) +@pytest.mark.asyncio +async def test_save_generated_custom_id(db): + cid = 12345 + mdl = await Tournament.create(id=cid, name="Test") + assert mdl.id == cid + mdl2 = await Tournament.get(id=cid) + assert mdl == mdl2 + + +@pytest.mark.asyncio +async def test_save_non_generated_custom_id(db): + cid = uuid4() + mdl = await UUIDFkRelatedNullModel.create(id=cid, name="Test") + assert mdl.id == cid + mdl2 = await UUIDFkRelatedNullModel.get(id=cid) + assert mdl == mdl2 + + +@requireCapability(dialect=NotEQ("mssql")) +@pytest.mark.asyncio +async def test_save_generated_duplicate_custom_id(db): + cid = 12345 + await Tournament.create(id=cid, name="TestOriginal") + with pytest.raises(IntegrityError): + await Tournament.create(id=cid, name="Test") + + +@pytest.mark.asyncio +async def test_save_non_generated_duplicate_custom_id(db): + cid = uuid4() + await UUIDFkRelatedNullModel.create(id=cid, name="TestOriginal") + with pytest.raises(IntegrityError): + await UUIDFkRelatedNullModel.create(id=cid, name="Test") + + +@pytest.mark.asyncio +async def test_clone_pk_required_error(db): + mdl = await 
RequiredPKModel.create(id="A", name="name_a") + with pytest.raises(ParamsError): + mdl.clone() + + +@pytest.mark.asyncio +async def test_clone_pk_required(db): + mdl = await RequiredPKModel.create(id="A", name="name_a") + mdl2 = mdl.clone(pk="B") + await mdl2.save() + mdls = list(await RequiredPKModel.all()) + assert len(mdls) == 2 + -class TestModelCreate(test.TestCase): - async def test_save_generated(self): - mdl = await Tournament.create(name="Test") - mdl2 = await Tournament.get(id=mdl.id) - self.assertEqual(mdl, mdl2) - - async def test_save_non_generated(self): - mdl = await UUIDFkRelatedNullModel.create(name="Test") - mdl2 = await UUIDFkRelatedNullModel.get(id=mdl.id) - self.assertEqual(mdl, mdl2) - - @test.requireCapability(dialect=NotEQ("mssql")) - async def test_save_generated_custom_id(self): - cid = 12345 - mdl = await Tournament.create(id=cid, name="Test") - self.assertEqual(mdl.id, cid) - mdl2 = await Tournament.get(id=cid) - self.assertEqual(mdl, mdl2) - - async def test_save_non_generated_custom_id(self): - cid = uuid4() - mdl = await UUIDFkRelatedNullModel.create(id=cid, name="Test") - self.assertEqual(mdl.id, cid) - mdl2 = await UUIDFkRelatedNullModel.get(id=cid) - self.assertEqual(mdl, mdl2) - - @test.requireCapability(dialect=NotEQ("mssql")) - async def test_save_generated_duplicate_custom_id(self): - cid = 12345 - await Tournament.create(id=cid, name="TestOriginal") - with self.assertRaises(IntegrityError): - await Tournament.create(id=cid, name="Test") - - async def test_save_non_generated_duplicate_custom_id(self): - cid = uuid4() - await UUIDFkRelatedNullModel.create(id=cid, name="TestOriginal") - with self.assertRaises(IntegrityError): - await UUIDFkRelatedNullModel.create(id=cid, name="Test") - - async def test_clone_pk_required_error(self): - mdl = await RequiredPKModel.create(id="A", name="name_a") - with self.assertRaises(ParamsError): - mdl.clone() - - async def test_clone_pk_required(self): - mdl = await 
RequiredPKModel.create(id="A", name="name_a") - mdl2 = mdl.clone(pk="B") - await mdl2.save() - mdls = list(await RequiredPKModel.all()) - self.assertEqual(len(mdls), 2) - - async def test_implicit_clone_pk_required_none(self): - mdl = await RequiredPKModel.create(id="A", name="name_a") - mdl.pk = None - with self.assertRaises(ValidationError): - await mdl.save() - - -class TestModelMethods(test.TestCase): - async def asyncSetUp(self): - await super().asyncSetUp() - self.cls = Tournament - self.mdl = await self.cls.create(name="Test") - self.mdl2 = self.cls(name="Test") - - async def test_save(self): - oldid = self.mdl.id - await self.mdl.save() - self.assertEqual(self.mdl.id, oldid) - - async def test_save_f_expression(self): - int_field = await IntFields.create(intnum=1) - int_field.intnum = F("intnum") + 1 - await int_field.save(update_fields=["intnum"]) - n_int = await IntFields.get(pk=int_field.pk) - self.assertEqual(n_int.intnum, 2) - - async def test_save_full(self): - self.mdl.name = "TestS" - self.mdl.desc = "Something" - await self.mdl.save() - n_mdl = await self.cls.get(id=self.mdl.id) - self.assertEqual(n_mdl.name, "TestS") - self.assertEqual(n_mdl.desc, "Something") - - async def test_save_partial(self): - self.mdl.name = "TestS" - self.mdl.desc = "Something" - await self.mdl.save(update_fields=["desc"]) - n_mdl = await self.cls.get(id=self.mdl.id) - self.assertEqual(n_mdl.name, "Test") - self.assertEqual(n_mdl.desc, "Something") - - async def test_save_partial_with_pk_update(self): - # Not allow to update pk field only - with self.assertRaisesRegex(OperationalError, "Can't update pk field"): - await self.mdl.save(update_fields=["id"]) - # So does update pk field with others - with self.assertRaisesRegex(OperationalError, f"use `{self.cls.__name__}.create` instead"): - await self.mdl.save(update_fields=["id", "desc"]) - - async def test_create(self): - mdl = self.cls(name="Test2") - self.assertIsNone(mdl.id) +@pytest.mark.asyncio +async def 
test_implicit_clone_pk_required_none(db): + mdl = await RequiredPKModel.create(id="A", name="name_a") + mdl.pk = None + with pytest.raises(ValidationError): await mdl.save() - self.assertIsNotNone(mdl.id) - - async def test_delete(self): - mdl = await self.cls.get(name="Test") - self.assertEqual(self.mdl.id, mdl.id) - - await self.mdl.delete() - - with self.assertRaises(DoesNotExist): - await self.cls.get(name="Test") - - with self.assertRaises(OperationalError): - await self.mdl2.delete() - - def test_str(self): - self.assertEqual(str(self.mdl), "Test") - - def test_repr(self): - self.assertEqual(repr(self.mdl), f"") - self.assertEqual(repr(self.mdl2), "") - - def test_hash(self): - self.assertEqual(hash(self.mdl), self.mdl.id) - with self.assertRaises(TypeError): - hash(self.mdl2) - - async def test_eq(self): - mdl = await self.cls.get(name="Test") - self.assertEqual(self.mdl, mdl) - - async def test_get_or_create(self): - mdl, created = await self.cls.get_or_create(name="Test") - self.assertFalse(created) - self.assertEqual(self.mdl, mdl) - mdl, created = await self.cls.get_or_create(name="Test2") - self.assertTrue(created) - self.assertNotEqual(self.mdl, mdl) - mdl2 = await self.cls.get(name="Test2") - self.assertEqual(mdl, mdl2) - - async def test_update_or_create(self): - mdl, created = await self.cls.update_or_create(name="Test") - self.assertFalse(created) - self.assertEqual(self.mdl, mdl) - mdl, created = await self.cls.update_or_create(name="Test2") - self.assertTrue(created) - self.assertNotEqual(self.mdl, mdl) - mdl2 = await self.cls.get(name="Test2") - self.assertEqual(mdl, mdl2) - - async def test_update_or_create_with_defaults(self): - mdl = await self.cls.get(name=self.mdl.name) - mdl_dict = dict(mdl) - oldid = mdl.id - mdl.id = 135 - with self.assertRaisesRegex(ParamsError, "Conflict value with key='id':"): - # Missing query: check conflict with kwargs and defaults before create - await self.cls.update_or_create(id=mdl.id, defaults=mdl_dict) - desc 
= str(uuid4()) - # If there is no conflict with defaults and kwargs, it will be success to update or create - defaults = dict(mdl_dict, desc=desc) - kwargs = {"id": defaults["id"], "name": defaults["name"]} - mdl, created = await self.cls.update_or_create(defaults, **kwargs) - self.assertFalse(created) - self.assertEqual(defaults["desc"], mdl.desc) - self.assertNotEqual(self.mdl.desc, mdl.desc) - # Hint query: use defauts to update without checking conflict - mdl2, created = await self.cls.update_or_create( - id=oldid, desc=desc, defaults=dict(mdl_dict, desc="new desc") - ) - self.assertFalse(created) - self.assertNotEqual(dict(mdl), dict(mdl2)) - # Missing query: success to create if no conflict - not_exist_name = str(uuid4()) - no_conflict_defaults = {"name": not_exist_name, "desc": desc} - no_conflict_kwargs = {"name": not_exist_name} - mdl, created = await self.cls.update_or_create(no_conflict_defaults, **no_conflict_kwargs) - self.assertTrue(created) - self.assertEqual(not_exist_name, mdl.name) - - async def test_first(self): - mdl = await self.cls.first() - self.assertEqual(self.mdl.id, mdl.id) - - async def test_last(self): - mdl = await self.cls.last() - self.assertEqual(self.mdl.id, mdl.id) - - async def test_latest(self): - mdl = await self.cls.latest("name") - self.assertEqual(self.mdl.id, mdl.id) - - async def test_earliest(self): - mdl = await self.cls.earliest("name") - self.assertEqual(self.mdl.id, mdl.id) - - async def test_filter(self): - mdl = await self.cls.filter(name="Test").first() - self.assertEqual(self.mdl.id, mdl.id) - mdl = await self.cls.filter(name="Test2").first() - self.assertIsNone(mdl) - - async def test_all(self): - mdls = list(await self.cls.all()) - self.assertEqual(len(mdls), 1) - self.assertEqual(mdls, [self.mdl]) - - async def test_get(self): - mdl = await self.cls.get(name="Test") - self.assertEqual(self.mdl.id, mdl.id) - - with self.assertRaises(DoesNotExist): - await self.cls.get(name="Test2") - - await 
self.cls.create(name="Test") - - with self.assertRaises(MultipleObjectsReturned): - await self.cls.get(name="Test") - - async def test_exists(self): - await self.cls.create(name="Test") - ret = await self.cls.exists(name="Test") - self.assertTrue(ret) - - ret = await self.cls.exists(name="XXX") - self.assertFalse(ret) - - ret = await self.cls.exists(Q(name="XXX") & Q(name="Test")) - self.assertFalse(ret) - - async def test_get_or_none(self): - mdl = await self.cls.get_or_none(name="Test") - self.assertEqual(self.mdl.id, mdl.id) - - mdl = await self.cls.get_or_none(name="Test2") - self.assertEqual(mdl, None) - - await self.cls.create(name="Test") - - with self.assertRaises(MultipleObjectsReturned): - await self.cls.get_or_none(name="Test") - - @test.skipIf(os.name == "nt", "timestamp issue on Windows") - async def test_update_from_dict(self): - evt1 = await Event.create(name="a", tournament=await Tournament.create(name="a")) - orig_modified = evt1.modified - await evt1.update_from_dict({"alias": "8", "name": "b", "bad_name": "foo"}).save() - self.assertEqual(evt1.alias, 8) - self.assertEqual(evt1.name, "b") - - with self.assertRaises(AttributeError): - _ = evt1.bad_name - - evt2 = await Event.get(name="b") - self.assertEqual(evt1.pk, evt2.pk) - self.assertEqual(evt1.modified, evt2.modified) - self.assertNotEqual(orig_modified, evt1.modified) - - with self.assertRaises(ConfigurationError): - await evt2.update_from_dict({"participants": []}) - - with self.assertRaises(ValueError): - await evt2.update_from_dict({"alias": "foo"}) - - async def test_index_access(self): - obj = await self.cls[self.mdl.pk] - self.assertEqual(obj, self.mdl) - - async def test_index_badval(self): - with self.assertRaises(ObjectDoesNotExistError) as cm: - await self.cls[32767] - the_exception = cm.exception - # For compatibility reasons this should be an instance of KeyError - self.assertIsInstance(the_exception, KeyError) - self.assertIs(the_exception.model, self.cls) - 
self.assertEqual(the_exception.pk_name, "id") - self.assertEqual(the_exception.pk_val, 32767) - self.assertEqual(str(the_exception), f"{self.cls.__name__} has no object with id=32767") - - async def test_index_badtype(self): - with self.assertRaises(ObjectDoesNotExistError) as cm: - await self.cls["asdf"] - the_exception = cm.exception - # For compatibility reasons this should be an instance of KeyError - self.assertIsInstance(the_exception, KeyError) - self.assertIs(the_exception.model, self.cls) - self.assertEqual(the_exception.pk_name, "id") - self.assertEqual(the_exception.pk_val, "asdf") - self.assertEqual(str(the_exception), f"{self.cls.__name__} has no object with id=asdf") - - async def test_clone(self): - mdl2 = self.mdl.clone() - self.assertEqual(mdl2.pk, None) - await mdl2.save() - self.assertNotEqual(mdl2.pk, self.mdl.pk) - mdls = list(await self.cls.all()) - self.assertEqual(len(mdls), 2) - - async def test_clone_with_pk(self): - mdl2 = self.mdl.clone(pk=8888) - self.assertEqual(mdl2.pk, 8888) - await mdl2.save() - self.assertNotEqual(mdl2.pk, self.mdl.pk) - await mdl2.save() - mdls = list(await self.cls.all()) - self.assertEqual(len(mdls), 2) - - async def test_clone_from_db(self): - mdl2 = await self.cls.get(pk=self.mdl.pk) - mdl3 = mdl2.clone() - mdl3.pk = None - await mdl3.save() - self.assertNotEqual(mdl3.pk, mdl2.pk) - mdls = list(await self.cls.all()) - self.assertEqual(len(mdls), 2) - - async def test_implicit_clone(self): - self.mdl.pk = None - await self.mdl.save() - mdls = list(await self.cls.all()) - self.assertEqual(len(mdls), 2) - - async def test_force_create(self): - obj = self.cls(name="Test", id=self.mdl.id) - with self.assertRaises(IntegrityError): - await obj.save(force_create=True) - - async def test_force_update(self): - obj = self.cls(name="Test3", id=self.mdl.id) + + +# ============================================================================ +# TestModelMethods fixtures and tests +# 
============================================================================ + + +@pytest_asyncio.fixture +async def tournament_model(db): + """Fixture that provides a saved Tournament model instance.""" + return await Tournament.create(name="Test") + + +@pytest_asyncio.fixture +async def tournament_model_unsaved(db): + """Fixture that provides an unsaved Tournament model instance.""" + return Tournament(name="Test") + + +@pytest.mark.asyncio +async def test_save(tournament_model): + mdl = tournament_model + oldid = mdl.id + await mdl.save() + assert mdl.id == oldid + + +@pytest.mark.asyncio +async def test_save_f_expression(db): + int_field = await IntFields.create(intnum=1) + int_field.intnum = F("intnum") + 1 + await int_field.save(update_fields=["intnum"]) + n_int = await IntFields.get(pk=int_field.pk) + assert n_int.intnum == 2 + + +@pytest.mark.asyncio +async def test_save_full(tournament_model): + tournament_model.name = "TestS" + tournament_model.desc = "Something" + await tournament_model.save() + n_mdl = await Tournament.get(id=tournament_model.id) + assert n_mdl.name == "TestS" + assert n_mdl.desc == "Something" + + +@pytest.mark.asyncio +async def test_save_partial(tournament_model): + tournament_model.name = "TestS" + tournament_model.desc = "Something" + await tournament_model.save(update_fields=["desc"]) + n_mdl = await Tournament.get(id=tournament_model.id) + assert n_mdl.name == "Test" + assert n_mdl.desc == "Something" + + +@pytest.mark.asyncio +async def test_save_partial_with_pk_update(tournament_model): + # Not allow to update pk field only + with pytest.raises(OperationalError, match="Can't update pk field"): + await tournament_model.save(update_fields=["id"]) + # So does update pk field with others + with pytest.raises(OperationalError, match=f"use `{Tournament.__name__}.create` instead"): + await tournament_model.save(update_fields=["id", "desc"]) + + +@pytest.mark.asyncio +async def test_create(db): + mdl = Tournament(name="Test2") + assert 
mdl.id is None + await mdl.save() + assert mdl.id is not None + + +@pytest.mark.asyncio +async def test_delete(tournament_model, tournament_model_unsaved): + fetched_mdl = await Tournament.get(name="Test") + assert tournament_model.id == fetched_mdl.id + + await tournament_model.delete() + + with pytest.raises(DoesNotExist): + await Tournament.get(name="Test") + + with pytest.raises(OperationalError): + await tournament_model_unsaved.delete() + + +@pytest.mark.asyncio +async def test_str(tournament_model): + mdl = tournament_model + assert str(mdl) == "Test" + + +@pytest.mark.asyncio +async def test_repr(tournament_model, tournament_model_unsaved): + assert repr(tournament_model) == f"" + assert repr(tournament_model_unsaved) == "" + + +@pytest.mark.asyncio +async def test_hash(tournament_model, tournament_model_unsaved): + assert hash(tournament_model) == tournament_model.id + with pytest.raises(TypeError): + hash(tournament_model_unsaved) + + +@pytest.mark.asyncio +async def test_eq(tournament_model): + mdl = tournament_model + fetched_mdl = await Tournament.get(name="Test") + assert mdl == fetched_mdl + + +@pytest.mark.asyncio +async def test_get_or_create(tournament_model): + mdl = tournament_model + fetched_mdl, created = await Tournament.get_or_create(name="Test") + assert created is False + assert mdl == fetched_mdl + new_mdl, created = await Tournament.get_or_create(name="Test2") + assert created is True + assert mdl != new_mdl + mdl2 = await Tournament.get(name="Test2") + assert new_mdl == mdl2 + + +@pytest.mark.asyncio +async def test_update_or_create(tournament_model): + mdl = tournament_model + fetched_mdl, created = await Tournament.update_or_create(name="Test") + assert created is False + assert mdl == fetched_mdl + new_mdl, created = await Tournament.update_or_create(name="Test2") + assert created is True + assert mdl != new_mdl + mdl2 = await Tournament.get(name="Test2") + assert new_mdl == mdl2 + + +@pytest.mark.asyncio +async def 
test_update_or_create_with_defaults(tournament_model): + mdl = tournament_model + fetched_mdl = await Tournament.get(name=mdl.name) + mdl_dict = dict(fetched_mdl) + oldid = fetched_mdl.id + fetched_mdl.id = 135 + with pytest.raises(ParamsError, match="Conflict value with key='id':"): + # Missing query: check conflict with kwargs and defaults before create + await Tournament.update_or_create(id=fetched_mdl.id, defaults=mdl_dict) + desc = str(uuid4()) + # If there is no conflict with defaults and kwargs, the update or create will succeed + defaults = dict(mdl_dict, desc=desc) + kwargs = {"id": defaults["id"], "name": defaults["name"]} + updated_mdl, created = await Tournament.update_or_create(defaults, **kwargs) + assert created is False + assert defaults["desc"] == updated_mdl.desc + assert mdl.desc != updated_mdl.desc + # Hint query: use defaults to update without checking conflict + mdl2, created = await Tournament.update_or_create( + id=oldid, desc=desc, defaults=dict(mdl_dict, desc="new desc") + ) + assert created is False + assert dict(updated_mdl) != dict(mdl2) + # Missing query: creation succeeds if there is no conflict + not_exist_name = str(uuid4()) + no_conflict_defaults = {"name": not_exist_name, "desc": desc} + no_conflict_kwargs = {"name": not_exist_name} + created_mdl, created = await Tournament.update_or_create( + no_conflict_defaults, **no_conflict_kwargs + ) + assert created is True + assert not_exist_name == created_mdl.name + + +@pytest.mark.asyncio +async def test_first(tournament_model): + mdl = tournament_model + fetched_mdl = await Tournament.first() + assert mdl.id == fetched_mdl.id + + +@pytest.mark.asyncio +async def test_last(tournament_model): + mdl = tournament_model + fetched_mdl = await Tournament.last() + assert mdl.id == fetched_mdl.id + + +@pytest.mark.asyncio +async def test_latest(tournament_model): + mdl = tournament_model + fetched_mdl = await Tournament.latest("name") + assert mdl.id == fetched_mdl.id + + +@pytest.mark.asyncio 
+async def test_earliest(tournament_model): + mdl = tournament_model + fetched_mdl = await Tournament.earliest("name") + assert mdl.id == fetched_mdl.id + + +@pytest.mark.asyncio +async def test_filter(tournament_model): + mdl = tournament_model + fetched_mdl = await Tournament.filter(name="Test").first() + assert mdl.id == fetched_mdl.id + fetched_mdl = await Tournament.filter(name="Test2").first() + assert fetched_mdl is None + + +@pytest.mark.asyncio +async def test_all(tournament_model): + mdl = tournament_model + mdls = list(await Tournament.all()) + assert len(mdls) == 1 + assert mdls == [mdl] + + +@pytest.mark.asyncio +async def test_get(tournament_model): + mdl = tournament_model + fetched_mdl = await Tournament.get(name="Test") + assert mdl.id == fetched_mdl.id + + with pytest.raises(DoesNotExist): + await Tournament.get(name="Test2") + + await Tournament.create(name="Test") + + with pytest.raises(MultipleObjectsReturned): + await Tournament.get(name="Test") + + +@pytest.mark.asyncio +async def test_exists(db): + await Tournament.create(name="Test") + ret = await Tournament.exists(name="Test") + assert ret is True + + ret = await Tournament.exists(name="XXX") + assert ret is False + + ret = await Tournament.exists(Q(name="XXX") & Q(name="Test")) + assert ret is False + + +@pytest.mark.asyncio +async def test_get_or_none(tournament_model): + mdl = tournament_model + fetched_mdl = await Tournament.get_or_none(name="Test") + assert mdl.id == fetched_mdl.id + + fetched_mdl = await Tournament.get_or_none(name="Test2") + assert fetched_mdl is None + + await Tournament.create(name="Test") + + with pytest.raises(MultipleObjectsReturned): + await Tournament.get_or_none(name="Test") + + +@pytest.mark.skipif(os.name == "nt", reason="timestamp issue on Windows") +@pytest.mark.asyncio +async def test_update_from_dict(db): + evt1 = await Event.create(name="a", tournament=await Tournament.create(name="a")) + orig_modified = evt1.modified + await 
evt1.update_from_dict({"alias": "8", "name": "b", "bad_name": "foo"}).save() + assert evt1.alias == 8 + assert evt1.name == "b" + + with pytest.raises(AttributeError): + _ = evt1.bad_name + + evt2 = await Event.get(name="b") + assert evt1.pk == evt2.pk + assert evt1.modified == evt2.modified + assert orig_modified != evt1.modified + + with pytest.raises(ConfigurationError): + await evt2.update_from_dict({"participants": []}) + + with pytest.raises(ValueError): + await evt2.update_from_dict({"alias": "foo"}) + + +@pytest.mark.asyncio +async def test_index_access(tournament_model): + obj = await Tournament[tournament_model.pk] + assert obj == tournament_model + + +@pytest.mark.asyncio +async def test_index_badval(db): + with pytest.raises(ObjectDoesNotExistError) as exc_info: + await Tournament[32767] + the_exception = exc_info.value + # For compatibility reasons this should be an instance of KeyError + assert isinstance(the_exception, KeyError) + assert the_exception.model is Tournament + assert the_exception.pk_name == "id" + assert the_exception.pk_val == 32767 + assert str(the_exception) == f"{Tournament.__name__} has no object with id=32767" + + +@pytest.mark.asyncio +async def test_index_badtype(db): + with pytest.raises(ObjectDoesNotExistError) as exc_info: + await Tournament["asdf"] + the_exception = exc_info.value + # For compatibility reasons this should be an instance of KeyError + assert isinstance(the_exception, KeyError) + assert the_exception.model is Tournament + assert the_exception.pk_name == "id" + assert the_exception.pk_val == "asdf" + assert str(the_exception) == f"{Tournament.__name__} has no object with id=asdf" + + +@pytest.mark.asyncio +async def test_clone(tournament_model): + mdl = tournament_model + mdl2 = mdl.clone() + assert mdl2.pk is None + await mdl2.save() + assert mdl2.pk != mdl.pk + mdls = list(await Tournament.all()) + assert len(mdls) == 2 + + +@pytest.mark.asyncio +async def test_clone_with_pk(tournament_model): + mdl = 
tournament_model + mdl2 = mdl.clone(pk=8888) + assert mdl2.pk == 8888 + await mdl2.save() + assert mdl2.pk != mdl.pk + await mdl2.save() + mdls = list(await Tournament.all()) + assert len(mdls) == 2 + + +@pytest.mark.asyncio +async def test_clone_from_db(tournament_model): + mdl = tournament_model + mdl2 = await Tournament.get(pk=mdl.pk) + mdl3 = mdl2.clone() + mdl3.pk = None + await mdl3.save() + assert mdl3.pk != mdl2.pk + mdls = list(await Tournament.all()) + assert len(mdls) == 2 + + +@pytest.mark.asyncio +async def test_implicit_clone(tournament_model): + mdl = tournament_model + mdl.pk = None + await mdl.save() + mdls = list(await Tournament.all()) + assert len(mdls) == 2 + + +@pytest.mark.asyncio +async def test_force_create(tournament_model): + obj = Tournament(name="Test", id=tournament_model.id) + with pytest.raises(IntegrityError): + await obj.save(force_create=True) + + +@pytest.mark.asyncio +async def test_force_update(tournament_model): + obj = Tournament(name="Test3", id=tournament_model.id) + await obj.save(force_update=True) + + +@pytest.mark.asyncio +async def test_force_update_raise(tournament_model): + obj = Tournament(name="Test3", id=tournament_model.id + 100) + with pytest.raises(IntegrityError): await obj.save(force_update=True) - async def test_force_update_raise(self): - obj = self.cls(name="Test3", id=self.mdl.id + 100) - with self.assertRaises(IntegrityError): - await obj.save(force_update=True) - - async def test_raw(self): - await Node.create(name="TestRaw") - ret = await Node.raw("select * from node where name='TestRaw'") - self.assertEqual(len(ret), 1) - ret = await Node.raw("select * from node where name='111'") - self.assertEqual(len(ret), 0) - - -class TestModelMethodsNoID(TestModelMethods): - async def asyncSetUp(self): - await super().asyncSetUp() - self.mdl = await NoID.create(name="Test") - self.mdl2 = NoID(name="Test") - self.cls = NoID - - def test_str(self): - self.assertEqual(str(self.mdl), "") - - def test_repr(self): - 
self.assertEqual(repr(self.mdl), f"") - self.assertEqual(repr(self.mdl2), "") - - -class TestModelConstructor(test.TestCase): - def test_null_in_nonnull_field(self): - with self.assertRaisesRegex(ValueError, "name is non nullable field, but null was passed"): - Event(name=None) - - def test_rev_fk(self): - with self.assertRaisesRegex( - ConfigurationError, - "You can't set backward relations through init, change related model instead", - ): - Tournament(name="a", events=[]) - - def test_m2m(self): - with self.assertRaisesRegex( - ConfigurationError, "You can't set m2m relations through init, use m2m_manager instead" - ): - Event(name="a", participants=[]) - - def test_rev_m2m(self): - with self.assertRaisesRegex( - ConfigurationError, "You can't set m2m relations through init, use m2m_manager instead" - ): - Team(name="a", events=[]) - - async def test_rev_o2o(self): - with self.assertRaisesRegex( - ConfigurationError, - "You can't set backward one to one relations through init, " - "change related model instead", - ): - address = await O2O_null.create(name="Ocean") - await Dest_null(name="a", address_null=address) - - def test_fk_unsaved(self): - with self.assertRaisesRegex(OperationalError, "You should first call .save()"): - Event(name="a", tournament=Tournament(name="a")) - - async def test_fk_saved(self): - await Event.create(name="a", tournament=await Tournament.create(name="a")) - - async def test_noneawaitable(self): - self.assertFalse(NoneAwaitable) - self.assertIsNone(await NoneAwaitable) - self.assertFalse(NoneAwaitable) - self.assertIsNone(await NoneAwaitable) + +@pytest.mark.asyncio +async def test_raw(db): + await Node.create(name="TestRaw") + ret = await Node.raw("select * from node where name='TestRaw'") + assert len(ret) == 1 + ret = await Node.raw("select * from node where name='111'") + assert len(ret) == 0 + + +# ============================================================================ +# TestModelMethodsNoID fixtures and tests +# 
============================================================================ + + +@pytest_asyncio.fixture +async def noid_model(db): + """Fixture that provides a saved NoID model instance.""" + return await NoID.create(name="Test") + + +@pytest_asyncio.fixture +async def noid_model_unsaved(db): + """Fixture that provides an unsaved NoID model instance.""" + return NoID(name="Test") + + +@pytest.mark.asyncio +async def test_noid_save(noid_model): + oldid = noid_model.id + await noid_model.save() + assert noid_model.id == oldid + + +@pytest.mark.asyncio +async def test_noid_save_f_expression(db): + int_field = await IntFields.create(intnum=1) + int_field.intnum = F("intnum") + 1 + await int_field.save(update_fields=["intnum"]) + n_int = await IntFields.get(pk=int_field.pk) + assert n_int.intnum == 2 + + +@pytest.mark.asyncio +async def test_noid_save_full(noid_model): + noid_model.name = "TestS" + await noid_model.save() + n_mdl = await NoID.get(id=noid_model.id) + assert n_mdl.name == "TestS" + + +@pytest.mark.asyncio +async def test_noid_save_partial(noid_model): + noid_model.name = "TestS" + noid_model.desc = "Something" + await noid_model.save(update_fields=["desc"]) + n_mdl = await NoID.get(id=noid_model.id) + assert n_mdl.name == "Test" # name should not be updated + assert n_mdl.desc == "Something" # desc should be updated + + +@pytest.mark.asyncio +async def test_noid_save_partial_with_pk_update(noid_model): + # Not allow to update pk field only + with pytest.raises(OperationalError, match="Can't update pk field"): + await noid_model.save(update_fields=["id"]) + # So does update pk field with others + with pytest.raises(OperationalError, match=f"use `{NoID.__name__}.create` instead"): + await noid_model.save(update_fields=["id", "name"]) + + +@pytest.mark.asyncio +async def test_noid_create(db): + mdl = NoID(name="Test2") + assert mdl.id is None + await mdl.save() + assert mdl.id is not None + + +@pytest.mark.asyncio +async def test_noid_delete(noid_model, 
noid_model_unsaved): + fetched_mdl = await NoID.get(name="Test") + assert noid_model.id == fetched_mdl.id + + await noid_model.delete() + + with pytest.raises(DoesNotExist): + await NoID.get(name="Test") + + with pytest.raises(OperationalError): + await noid_model_unsaved.delete() + + +@pytest.mark.asyncio +async def test_noid_str(noid_model): + assert str(noid_model) == "" + + +@pytest.mark.asyncio +async def test_noid_repr(noid_model, noid_model_unsaved): + assert repr(noid_model) == f"" + assert repr(noid_model_unsaved) == "" + + +@pytest.mark.asyncio +async def test_noid_hash(noid_model, noid_model_unsaved): + assert hash(noid_model) == noid_model.id + with pytest.raises(TypeError): + hash(noid_model_unsaved) + + +@pytest.mark.asyncio +async def test_noid_eq(noid_model): + fetched_mdl = await NoID.get(name="Test") + assert noid_model == fetched_mdl + + +@pytest.mark.asyncio +async def test_noid_get_or_create(noid_model): + fetched_mdl, created = await NoID.get_or_create(name="Test") + assert created is False + assert noid_model == fetched_mdl + new_mdl, created = await NoID.get_or_create(name="Test2") + assert created is True + assert noid_model != new_mdl + mdl2 = await NoID.get(name="Test2") + assert new_mdl == mdl2 + + +@pytest.mark.asyncio +async def test_noid_update_or_create(noid_model): + fetched_mdl, created = await NoID.update_or_create(name="Test") + assert created is False + assert noid_model == fetched_mdl + new_mdl, created = await NoID.update_or_create(name="Test2") + assert created is True + assert noid_model != new_mdl + mdl2 = await NoID.get(name="Test2") + assert new_mdl == mdl2 + + +@pytest.mark.asyncio +async def test_noid_first(noid_model): + fetched_mdl = await NoID.first() + assert noid_model.id == fetched_mdl.id + + +@pytest.mark.asyncio +async def test_noid_last(noid_model): + fetched_mdl = await NoID.last() + assert noid_model.id == fetched_mdl.id + + +@pytest.mark.asyncio +async def test_noid_latest(noid_model): + fetched_mdl = await 
NoID.latest("name") + assert noid_model.id == fetched_mdl.id + + +@pytest.mark.asyncio +async def test_noid_earliest(noid_model): + fetched_mdl = await NoID.earliest("name") + assert noid_model.id == fetched_mdl.id + + +@pytest.mark.asyncio +async def test_noid_filter(noid_model): + fetched_mdl = await NoID.filter(name="Test").first() + assert noid_model.id == fetched_mdl.id + fetched_mdl = await NoID.filter(name="Test2").first() + assert fetched_mdl is None + + +@pytest.mark.asyncio +async def test_noid_all(noid_model): + mdls = list(await NoID.all()) + assert len(mdls) == 1 + assert mdls == [noid_model] + + +@pytest.mark.asyncio +async def test_noid_get(noid_model): + fetched_mdl = await NoID.get(name="Test") + assert noid_model.id == fetched_mdl.id + + with pytest.raises(DoesNotExist): + await NoID.get(name="Test2") + + await NoID.create(name="Test") + + with pytest.raises(MultipleObjectsReturned): + await NoID.get(name="Test") + + +@pytest.mark.asyncio +async def test_noid_exists(noid_model): + await NoID.create(name="Test") + ret = await NoID.exists(name="Test") + assert ret is True + + ret = await NoID.exists(name="XXX") + assert ret is False + + ret = await NoID.exists(Q(name="XXX") & Q(name="Test")) + assert ret is False + + +@pytest.mark.asyncio +async def test_noid_get_or_none(noid_model): + fetched_mdl = await NoID.get_or_none(name="Test") + assert noid_model.id == fetched_mdl.id + + fetched_mdl = await NoID.get_or_none(name="Test2") + assert fetched_mdl is None + + await NoID.create(name="Test") + + with pytest.raises(MultipleObjectsReturned): + await NoID.get_or_none(name="Test") + + +@pytest.mark.asyncio +async def test_noid_index_access(noid_model): + obj = await NoID[noid_model.pk] + assert obj == noid_model + + +@pytest.mark.asyncio +async def test_noid_index_badval(db): + with pytest.raises(ObjectDoesNotExistError) as exc_info: + await NoID[32767] + the_exception = exc_info.value + # For compatibility reasons this should be an instance of KeyError 
+ assert isinstance(the_exception, KeyError) + assert the_exception.model is NoID + assert the_exception.pk_name == "id" + assert the_exception.pk_val == 32767 + assert str(the_exception) == f"{NoID.__name__} has no object with id=32767" + + +@pytest.mark.asyncio +async def test_noid_index_badtype(db): + with pytest.raises(ObjectDoesNotExistError) as exc_info: + await NoID["asdf"] + the_exception = exc_info.value + # For compatibility reasons this should be an instance of KeyError + assert isinstance(the_exception, KeyError) + assert the_exception.model is NoID + assert the_exception.pk_name == "id" + assert the_exception.pk_val == "asdf" + assert str(the_exception) == f"{NoID.__name__} has no object with id=asdf" + + +@pytest.mark.asyncio +async def test_noid_clone(noid_model): + mdl2 = noid_model.clone() + assert mdl2.pk is None + await mdl2.save() + assert mdl2.pk != noid_model.pk + mdls = list(await NoID.all()) + assert len(mdls) == 2 + + +@pytest.mark.asyncio +async def test_noid_clone_with_pk(noid_model): + mdl2 = noid_model.clone(pk=8888) + assert mdl2.pk == 8888 + await mdl2.save() + assert mdl2.pk != noid_model.pk + await mdl2.save() + mdls = list(await NoID.all()) + assert len(mdls) == 2 + + +@pytest.mark.asyncio +async def test_noid_clone_from_db(noid_model): + mdl2 = await NoID.get(pk=noid_model.pk) + mdl3 = mdl2.clone() + mdl3.pk = None + await mdl3.save() + assert mdl3.pk != mdl2.pk + mdls = list(await NoID.all()) + assert len(mdls) == 2 + + +@pytest.mark.asyncio +async def test_noid_implicit_clone(noid_model): + noid_model.pk = None + await noid_model.save() + mdls = list(await NoID.all()) + assert len(mdls) == 2 + + +@pytest.mark.asyncio +async def test_noid_force_create(noid_model): + obj = NoID(name="Test", id=noid_model.id) + with pytest.raises(IntegrityError): + await obj.save(force_create=True) + + +@pytest.mark.asyncio +async def test_noid_force_update(noid_model): + obj = NoID(name="Test3", id=noid_model.id) + await 
obj.save(force_update=True) + + +@pytest.mark.asyncio +async def test_noid_force_update_raise(noid_model): + obj = NoID(name="Test3", id=noid_model.id + 100) + with pytest.raises(IntegrityError): + await obj.save(force_update=True) + + +# ============================================================================ +# TestModelConstructor +# ============================================================================ + + +def test_null_in_nonnull_field(): + with pytest.raises(ValueError, match="name is non nullable field, but null was passed"): + Event(name=None) + + +def test_rev_fk(): + with pytest.raises( + ConfigurationError, + match="You can't set backward relations through init, change related model instead", + ): + Tournament(name="a", events=[]) + + +def test_m2m(): + with pytest.raises( + ConfigurationError, + match="You can't set m2m relations through init, use m2m_manager instead", + ): + Event(name="a", participants=[]) + + +def test_rev_m2m(): + with pytest.raises( + ConfigurationError, + match="You can't set m2m relations through init, use m2m_manager instead", + ): + Team(name="a", events=[]) + + +@pytest.mark.asyncio +async def test_rev_o2o(db): + with pytest.raises( + ConfigurationError, + match="You can't set backward one to one relations through init, " + "change related model instead", + ): + address = await O2O_null.create(name="Ocean") + await Dest_null(name="a", address_null=address) + + +def test_fk_unsaved(): + with pytest.raises(OperationalError, match="You should first call .save()"): + Event(name="a", tournament=Tournament(name="a")) + + +@pytest.mark.asyncio +async def test_fk_saved(db): + await Event.create(name="a", tournament=await Tournament.create(name="a")) + + +@pytest.mark.asyncio +async def test_noneawaitable(db): + assert not NoneAwaitable + assert await NoneAwaitable is None + assert not NoneAwaitable + assert await NoneAwaitable is None diff --git a/tests/test_only.py b/tests/test_only.py index 2a5eb6e28..c2b3c8fb3 100644 --- 
a/tests/test_only.py +++ b/tests/test_only.py @@ -1,294 +1,417 @@ +import pytest +import pytest_asyncio + from tests.testmodels import DoubleFK, Event, SourceFields, StraightFields, Tournament -from tortoise.contrib import test from tortoise.exceptions import FieldError, IncompleteInstanceError from tortoise.functions import Count +# ============================================================================ +# Fixtures for TestOnlyStraight and TestOnlySource +# ============================================================================ + + +@pytest_asyncio.fixture +async def straight_fields_instance(db): + """Create a StraightFields instance for testing.""" + return await StraightFields.create(chars="Test") + + +@pytest_asyncio.fixture +async def source_fields_instance(db): + """Create a SourceFields instance for testing.""" + return await SourceFields.create(chars="Test") -class TestOnlyStraight(test.TestCase): - async def asyncSetUp(self) -> None: - await super().asyncSetUp() - self.model = StraightFields - self.instance = await self.model.create(chars="Test") - async def test_get(self): - instance_part = await self.model.get(chars="Test").only("chars", "blip") +# ============================================================================ +# TestOnlyStraight tests +# ============================================================================ - self.assertEqual(instance_part.chars, "Test") - with self.assertRaises(AttributeError): - _ = instance_part.nullable - async def test_filter(self): - instances = await self.model.filter(chars="Test").only("chars", "blip") +@pytest.mark.asyncio +async def test_only_straight_get(db, straight_fields_instance): + instance_part = await StraightFields.get(chars="Test").only("chars", "blip") - self.assertEqual(len(instances), 1) - self.assertEqual(instances[0].chars, "Test") - with self.assertRaises(AttributeError): - _ = instances[0].nullable + assert instance_part.chars == "Test" + with pytest.raises(AttributeError): + _ = 
instance_part.nullable - async def test_first(self): - instance_part = await self.model.filter(chars="Test").only("chars", "blip").first() - self.assertEqual(instance_part.chars, "Test") - with self.assertRaises(AttributeError): - _ = instance_part.nullable +@pytest.mark.asyncio +async def test_only_straight_filter(db, straight_fields_instance): + instances = await StraightFields.filter(chars="Test").only("chars", "blip") - async def test_save(self): - instance_part = await self.model.get(chars="Test").only("chars", "blip") + assert len(instances) == 1 + assert instances[0].chars == "Test" + with pytest.raises(AttributeError): + _ = instances[0].nullable - with self.assertRaisesRegex(IncompleteInstanceError, " is a partial model"): - await instance_part.save() - async def test_partial_save(self): - instance_part = await self.model.get(chars="Test").only("chars", "blip") +@pytest.mark.asyncio +async def test_only_straight_first(db, straight_fields_instance): + instance_part = await StraightFields.filter(chars="Test").only("chars", "blip").first() - with self.assertRaisesRegex(IncompleteInstanceError, "Partial update not available"): - await instance_part.save(update_fields=["chars"]) + assert instance_part.chars == "Test" + with pytest.raises(AttributeError): + _ = instance_part.nullable - async def test_partial_save_with_pk_wrong_field(self): - instance_part = await self.model.get(chars="Test").only("chars", "eyedee") - with self.assertRaisesRegex(IncompleteInstanceError, "field 'nullable' is not available"): - await instance_part.save(update_fields=["nullable"]) +@pytest.mark.asyncio +async def test_only_straight_save(db, straight_fields_instance): + instance_part = await StraightFields.get(chars="Test").only("chars", "blip") - async def test_partial_save_with_pk(self): - instance_part = await self.model.get(chars="Test").only("chars", "eyedee") + with pytest.raises(IncompleteInstanceError, match=" is a partial model"): + await instance_part.save() - 
instance_part.chars = "Test1" + +@pytest.mark.asyncio +async def test_only_straight_partial_save(db, straight_fields_instance): + instance_part = await StraightFields.get(chars="Test").only("chars", "blip") + + with pytest.raises(IncompleteInstanceError, match="Partial update not available"): await instance_part.save(update_fields=["chars"]) - instance2 = await self.model.get(pk=self.instance.pk) - self.assertEqual(instance2.chars, "Test1") - - -class TestOnlySource(TestOnlyStraight): - async def asyncSetUp(self) -> None: - await super().asyncSetUp() - self.model = SourceFields # type: ignore - self.instance = await self.model.create(chars="Test") - - -class TestOnlyRecursive(test.TestCase): - async def test_one_level(self): - left_1st_lvl = await DoubleFK.create(name="1st") - root = await DoubleFK.create(name="root", left=left_1st_lvl) - - ret = ( - await DoubleFK.filter(pk=root.pk).only("name", "left__name", "left__left__name").first() - ) - self.assertIsNotNone(ret) - with self.assertRaises(AttributeError): - _ = ret.id - self.assertEqual(ret.name, "root") - self.assertEqual(ret.left.name, "1st") - with self.assertRaises(AttributeError): - _ = ret.left.id - with self.assertRaises(AttributeError): - _ = ret.right - - async def test_two_levels(self): - left_2nd_lvl = await DoubleFK.create(name="second leaf") - left_1st_lvl = await DoubleFK.create(name="1st", left=left_2nd_lvl) - root = await DoubleFK.create(name="root", left=left_1st_lvl) - - ret = ( - await DoubleFK.filter(pk=root.pk).only("name", "left__name", "left__left__name").first() - ) - self.assertIsNotNone(ret) - with self.assertRaises(AttributeError): - _ = ret.id - self.assertEqual(ret.name, "root") - self.assertEqual(ret.left.name, "1st") - with self.assertRaises(AttributeError): - _ = ret.left.id - self.assertEqual(ret.left.left.name, "second leaf") - - async def test_two_levels_reverse_argument_order(self): - left_2nd_lvl = await DoubleFK.create(name="second leaf") - left_1st_lvl = await 
DoubleFK.create(name="1st", left=left_2nd_lvl) - root = await DoubleFK.create(name="root", left=left_1st_lvl) - - ret = ( - await DoubleFK.filter(pk=root.pk).only("left__left__name", "left__name", "name").first() - ) - self.assertIsNotNone(ret) - with self.assertRaises(AttributeError): - _ = ret.id - self.assertEqual(ret.name, "root") - self.assertEqual(ret.left.name, "1st") - with self.assertRaises(AttributeError): - _ = ret.left.id - self.assertEqual(ret.left.left.name, "second leaf") - - -class TestOnlyRelated(test.TestCase): - async def test_related_one_level(self): - tournament = await Tournament.create(name="New Tournament", desc="New Description") - await Event.create(name="Event 1", tournament=tournament) - await Event.create(name="Event 2", tournament=tournament) - - ret = ( - await Event.filter(tournament=tournament) - .only("name", "tournament__name") - .order_by("name") - ) - self.assertEqual(len(ret), 2) - self.assertEqual(ret[0].name, "Event 1") - with self.assertRaises(AttributeError): - _ = ret[0].alias - self.assertEqual(ret[1].name, "Event 2") - with self.assertRaises(AttributeError): - _ = ret[1].alias - self.assertEqual(ret[0].tournament.name, "New Tournament") - with self.assertRaises(AttributeError): - _ = ret[0].tournament.id - with self.assertRaises(AttributeError): - _ = ret[0].tournament.desc - - async def test_related_one_level_reversed_argument_order(self): - tournament = await Tournament.create(name="New Tournament", desc="New Description") - await Event.create(name="Event 1", tournament=tournament) - await Event.create(name="Event 2", tournament=tournament) - - ret = ( - await Event.filter(tournament=tournament) - .only("tournament__name", "name") - .order_by("name") - ) - self.assertEqual(len(ret), 2) - self.assertEqual(ret[0].name, "Event 1") - self.assertEqual(ret[0].tournament.name, "New Tournament") - - async def test_just_related(self): - tournament = await Tournament.create(name="New Tournament", desc="New Description") - await 
Event.create(name="Event 1", tournament=tournament) - await Event.create(name="Event 2", tournament=tournament) - - ret = ( - await Event.filter(tournament=tournament) - .only("tournament__name") - .order_by("name") - .all() - ) - self.assertEqual(len(ret), 2) - self.assertEqual(ret[0].tournament.name, "New Tournament") - self.assertEqual(ret[1].tournament.name, "New Tournament") - - -class TestOnlyAdvanced(test.TestCase): - async def asyncSetUp(self) -> None: - await super().asyncSetUp() - self.tournament = await Tournament.create(name="Tournament A", desc="Description A") - self.event1 = await Event.create(name="Event 1", tournament=self.tournament) - self.event2 = await Event.create(name="Event 2", tournament=self.tournament) - - async def test_exclude(self): - """Test .only() combined with .exclude()""" - events = await Event.filter(tournament=self.tournament).exclude(name="Event 2").only("name") - self.assertEqual(len(events), 1) - self.assertEqual(events[0].name, "Event 1") - with self.assertRaises(AttributeError): - _ = events[0].modified - - async def test_limit(self): - """Test .only() combined with .limit()""" - events = await Event.all().only("name").limit(1) - self.assertEqual(len(events), 1) - self.assertEqual(events[0].name, "Event 1") # Assumes ordering by PK - with self.assertRaises(AttributeError): - _ = events[0].modified - - async def test_distinct(self): - """Test .only() combined with .distinct()""" - # Create duplicate event names - await Event.create(name="Event 1", tournament=self.tournament) - - events = await Event.all().only("name").distinct() - # Should only have 2 distinct event names - self.assertEqual(len(events), 2) - event_names = {e.name for e in events} - self.assertEqual(event_names, {"Event 1", "Event 2"}) - - async def test_values(self): - """Test .only() combined with .values()""" - with self.assertRaises(ValueError, msg="values() cannot be used with .only()"): - await Event.all().only("name").values("name") - - async def 
test_pk_field(self): - """Test .only() with just the primary key field""" - tournament = await Tournament.first().only("id") - self.assertIsNotNone(tournament.id) - with self.assertRaises(AttributeError): - _ = tournament.name - - async def test_empty(self): - """Test .only() with no fields (should raise an error)""" - with self.assertRaises(ValueError): - await Event.all().only() - - async def test_annotate(self): - tournaments = await Tournament.annotate(event_count=Count("events")).only( - "name", "event_count" - ) - - self.assertEqual(tournaments[0].name, "Tournament A") - self.assertEqual(tournaments[0].event_count, 2) - with self.assertRaises(AttributeError): - _ = tournaments[0].desc - - async def test_nonexistent_field(self): - """Test .only() with a field that doesn't exist""" - with self.assertRaises(FieldError): - await Event.all().only("nonexistent_field").all() - - async def test_join_in_filter(self): - event = await Event.filter(name="Event 1").only("name").first() - self.assertEqual(event.name, "Event 1") - with self.assertRaises(AttributeError): - _ = event.tournament - - event = await Event.filter(tournament__name="Tournament A").only("name").first() - self.assertEqual(event.name, "Event 1") - with self.assertRaises(AttributeError): - _ = event.tournament - - event = ( - await Event.filter(tournament__name="Tournament A") - .only("name", "tournament__name") - .first() - ) - self.assertEqual(event.name, "Event 1") - self.assertEqual(event.tournament.name, "Tournament A") - - async def test_join_in_order_by(self): - events = await Event.all().order_by("name").only("name") - self.assertEqual(events[0].name, "Event 1") - with self.assertRaises(AttributeError): - _ = events[0].tournament - - events = await Event.all().order_by("tournament__name", "name").only("name") - self.assertEqual(events[0].name, "Event 1") - with self.assertRaises(AttributeError): - _ = events[0].tournament - - events = ( - await Event.all().order_by("tournament__name", 
"name").only("name", "tournament__name") - ) - self.assertEqual(events[0].name, "Event 1") - self.assertEqual(events[0].tournament.name, "Tournament A") - - async def test_select_related(self): - """Test .only() with .select_related() for basic functionality""" - event = ( - await Event.filter(name="Event 1") - .select_related("tournament") - .only("name", "tournament__name") - .first() - ) - - self.assertEqual(event.name, "Event 1") - self.assertEqual(event.tournament.name, "Tournament A") - - with self.assertRaises(AttributeError): - _ = event.id - with self.assertRaises(AttributeError): - _ = event.tournament.id + +@pytest.mark.asyncio +async def test_only_straight_partial_save_with_pk_wrong_field(db, straight_fields_instance): + instance_part = await StraightFields.get(chars="Test").only("chars", "eyedee") + + with pytest.raises(IncompleteInstanceError, match="field 'nullable' is not available"): + await instance_part.save(update_fields=["nullable"]) + + +@pytest.mark.asyncio +async def test_only_straight_partial_save_with_pk(db, straight_fields_instance): + instance_part = await StraightFields.get(chars="Test").only("chars", "eyedee") + + instance_part.chars = "Test1" + await instance_part.save(update_fields=["chars"]) + + instance2 = await StraightFields.get(pk=straight_fields_instance.pk) + assert instance2.chars == "Test1" + + +# ============================================================================ +# TestOnlySource tests (same as Straight but with SourceFields model) +# ============================================================================ + + +@pytest.mark.asyncio +async def test_only_source_get(db, source_fields_instance): + instance_part = await SourceFields.get(chars="Test").only("chars", "blip") + + assert instance_part.chars == "Test" + with pytest.raises(AttributeError): + _ = instance_part.nullable + + +@pytest.mark.asyncio +async def test_only_source_filter(db, source_fields_instance): + instances = await 
SourceFields.filter(chars="Test").only("chars", "blip") + + assert len(instances) == 1 + assert instances[0].chars == "Test" + with pytest.raises(AttributeError): + _ = instances[0].nullable + + +@pytest.mark.asyncio +async def test_only_source_first(db, source_fields_instance): + instance_part = await SourceFields.filter(chars="Test").only("chars", "blip").first() + + assert instance_part.chars == "Test" + with pytest.raises(AttributeError): + _ = instance_part.nullable + + +@pytest.mark.asyncio +async def test_only_source_save(db, source_fields_instance): + instance_part = await SourceFields.get(chars="Test").only("chars", "blip") + + with pytest.raises(IncompleteInstanceError, match=" is a partial model"): + await instance_part.save() + + +@pytest.mark.asyncio +async def test_only_source_partial_save(db, source_fields_instance): + instance_part = await SourceFields.get(chars="Test").only("chars", "blip") + + with pytest.raises(IncompleteInstanceError, match="Partial update not available"): + await instance_part.save(update_fields=["chars"]) + + +@pytest.mark.asyncio +async def test_only_source_partial_save_with_pk_wrong_field(db, source_fields_instance): + instance_part = await SourceFields.get(chars="Test").only("chars", "eyedee") + + with pytest.raises(IncompleteInstanceError, match="field 'nullable' is not available"): + await instance_part.save(update_fields=["nullable"]) + + +@pytest.mark.asyncio +async def test_only_source_partial_save_with_pk(db, source_fields_instance): + instance_part = await SourceFields.get(chars="Test").only("chars", "eyedee") + + instance_part.chars = "Test1" + await instance_part.save(update_fields=["chars"]) + + instance2 = await SourceFields.get(pk=source_fields_instance.pk) + assert instance2.chars == "Test1" + + +# ============================================================================ +# TestOnlyRecursive tests +# ============================================================================ + + +@pytest.mark.asyncio +async 
def test_only_recursive_one_level(db): + left_1st_lvl = await DoubleFK.create(name="1st") + root = await DoubleFK.create(name="root", left=left_1st_lvl) + + ret = await DoubleFK.filter(pk=root.pk).only("name", "left__name", "left__left__name").first() + assert ret is not None + with pytest.raises(AttributeError): + _ = ret.id + assert ret.name == "root" + assert ret.left.name == "1st" + with pytest.raises(AttributeError): + _ = ret.left.id + with pytest.raises(AttributeError): + _ = ret.right + + +@pytest.mark.asyncio +async def test_only_recursive_two_levels(db): + left_2nd_lvl = await DoubleFK.create(name="second leaf") + left_1st_lvl = await DoubleFK.create(name="1st", left=left_2nd_lvl) + root = await DoubleFK.create(name="root", left=left_1st_lvl) + + ret = await DoubleFK.filter(pk=root.pk).only("name", "left__name", "left__left__name").first() + assert ret is not None + with pytest.raises(AttributeError): + _ = ret.id + assert ret.name == "root" + assert ret.left.name == "1st" + with pytest.raises(AttributeError): + _ = ret.left.id + assert ret.left.left.name == "second leaf" + + +@pytest.mark.asyncio +async def test_only_recursive_two_levels_reverse_argument_order(db): + left_2nd_lvl = await DoubleFK.create(name="second leaf") + left_1st_lvl = await DoubleFK.create(name="1st", left=left_2nd_lvl) + root = await DoubleFK.create(name="root", left=left_1st_lvl) + + ret = await DoubleFK.filter(pk=root.pk).only("left__left__name", "left__name", "name").first() + assert ret is not None + with pytest.raises(AttributeError): + _ = ret.id + assert ret.name == "root" + assert ret.left.name == "1st" + with pytest.raises(AttributeError): + _ = ret.left.id + assert ret.left.left.name == "second leaf" + + +# ============================================================================ +# TestOnlyRelated tests +# ============================================================================ + + +@pytest.mark.asyncio +async def test_only_related_one_level(db): + tournament = 
await Tournament.create(name="New Tournament", desc="New Description") + await Event.create(name="Event 1", tournament=tournament) + await Event.create(name="Event 2", tournament=tournament) + + ret = ( + await Event.filter(tournament=tournament).only("name", "tournament__name").order_by("name") + ) + assert len(ret) == 2 + assert ret[0].name == "Event 1" + with pytest.raises(AttributeError): + _ = ret[0].alias + assert ret[1].name == "Event 2" + with pytest.raises(AttributeError): + _ = ret[1].alias + assert ret[0].tournament.name == "New Tournament" + with pytest.raises(AttributeError): + _ = ret[0].tournament.id + with pytest.raises(AttributeError): + _ = ret[0].tournament.desc + + +@pytest.mark.asyncio +async def test_only_related_one_level_reversed_argument_order(db): + tournament = await Tournament.create(name="New Tournament", desc="New Description") + await Event.create(name="Event 1", tournament=tournament) + await Event.create(name="Event 2", tournament=tournament) + + ret = ( + await Event.filter(tournament=tournament).only("tournament__name", "name").order_by("name") + ) + assert len(ret) == 2 + assert ret[0].name == "Event 1" + assert ret[0].tournament.name == "New Tournament" + + +@pytest.mark.asyncio +async def test_only_related_just_related(db): + tournament = await Tournament.create(name="New Tournament", desc="New Description") + await Event.create(name="Event 1", tournament=tournament) + await Event.create(name="Event 2", tournament=tournament) + + ret = await Event.filter(tournament=tournament).only("tournament__name").order_by("name").all() + assert len(ret) == 2 + assert ret[0].tournament.name == "New Tournament" + assert ret[1].tournament.name == "New Tournament" + + +# ============================================================================ +# Fixture for TestOnlyAdvanced tests +# ============================================================================ + + +@pytest_asyncio.fixture +async def tournament_with_events(db): + """Create a 
tournament with two events for advanced tests.""" + tournament = await Tournament.create(name="Tournament A", desc="Description A") + event1 = await Event.create(name="Event 1", tournament=tournament) + event2 = await Event.create(name="Event 2", tournament=tournament) + return tournament, event1, event2 + + +# ============================================================================ +# TestOnlyAdvanced tests +# ============================================================================ + + +@pytest.mark.asyncio +async def test_only_advanced_exclude(db, tournament_with_events): + """Test .only() combined with .exclude()""" + tournament, event1, event2 = tournament_with_events + events = await Event.filter(tournament=tournament).exclude(name="Event 2").only("name") + assert len(events) == 1 + assert events[0].name == "Event 1" + with pytest.raises(AttributeError): + _ = events[0].modified + + +@pytest.mark.asyncio +async def test_only_advanced_limit(db, tournament_with_events): + """Test .only() combined with .limit()""" + events = await Event.all().only("name").limit(1) + assert len(events) == 1 + assert events[0].name == "Event 1" # Assumes ordering by PK + with pytest.raises(AttributeError): + _ = events[0].modified + + +@pytest.mark.asyncio +async def test_only_advanced_distinct(db, tournament_with_events): + """Test .only() combined with .distinct()""" + tournament, event1, event2 = tournament_with_events + # Create duplicate event names + await Event.create(name="Event 1", tournament=tournament) + + events = await Event.all().only("name").distinct() + # Should only have 2 distinct event names + assert len(events) == 2 + event_names = {e.name for e in events} + assert event_names == {"Event 1", "Event 2"} + + +@pytest.mark.asyncio +async def test_only_advanced_values(db, tournament_with_events): + """Test .only() combined with .values()""" + with pytest.raises(ValueError): + await Event.all().only("name").values("name") + + +@pytest.mark.asyncio +async def 
test_only_advanced_pk_field(db, tournament_with_events): + """Test .only() with just the primary key field""" + tournament = await Tournament.first().only("id") + assert tournament.id is not None + with pytest.raises(AttributeError): + _ = tournament.name + + +@pytest.mark.asyncio +async def test_only_advanced_empty(db, tournament_with_events): + """Test .only() with no fields (should raise an error)""" + with pytest.raises(ValueError): + await Event.all().only() + + +@pytest.mark.asyncio +async def test_only_advanced_annotate(db, tournament_with_events): + tournaments = await Tournament.annotate(event_count=Count("events")).only("name", "event_count") + + assert tournaments[0].name == "Tournament A" + assert tournaments[0].event_count == 2 + with pytest.raises(AttributeError): + _ = tournaments[0].desc + + +@pytest.mark.asyncio +async def test_only_advanced_nonexistent_field(db, tournament_with_events): + """Test .only() with a field that doesn't exist""" + with pytest.raises(FieldError): + await Event.all().only("nonexistent_field").all() + + +@pytest.mark.asyncio +async def test_only_advanced_join_in_filter(db, tournament_with_events): + event = await Event.filter(name="Event 1").only("name").first() + assert event.name == "Event 1" + with pytest.raises(AttributeError): + _ = event.tournament + + event = await Event.filter(tournament__name="Tournament A").only("name").first() + assert event.name == "Event 1" + with pytest.raises(AttributeError): + _ = event.tournament + + event = ( + await Event.filter(tournament__name="Tournament A").only("name", "tournament__name").first() + ) + assert event.name == "Event 1" + assert event.tournament.name == "Tournament A" + + +@pytest.mark.asyncio +async def test_only_advanced_join_in_order_by(db, tournament_with_events): + events = await Event.all().order_by("name").only("name") + assert events[0].name == "Event 1" + with pytest.raises(AttributeError): + _ = events[0].tournament + + events = await 
Event.all().order_by("tournament__name", "name").only("name") + assert events[0].name == "Event 1" + with pytest.raises(AttributeError): + _ = events[0].tournament + + events = await Event.all().order_by("tournament__name", "name").only("name", "tournament__name") + assert events[0].name == "Event 1" + assert events[0].tournament.name == "Tournament A" + + +@pytest.mark.asyncio +async def test_only_advanced_select_related(db, tournament_with_events): + """Test .only() with .select_related() for basic functionality""" + event = ( + await Event.filter(name="Event 1") + .select_related("tournament") + .only("name", "tournament__name") + .first() + ) + + assert event.name == "Event 1" + assert event.tournament.name == "Tournament A" + + with pytest.raises(AttributeError): + _ = event.id + with pytest.raises(AttributeError): + _ = event.tournament.id diff --git a/tests/test_order_by.py b/tests/test_order_by.py index 8f8f1c051..0b80bb96a 100644 --- a/tests/test_order_by.py +++ b/tests/test_order_by.py @@ -1,3 +1,5 @@ +import pytest + from tests.testmodels import ( DefaultOrdered, DefaultOrderedDesc, @@ -12,169 +14,202 @@ from tortoise.expressions import Case, Q, When from tortoise.functions import Count, Lower, Sum +# ============================================================================ +# TestOrderBy tests +# ============================================================================ -class TestOrderBy(test.TestCase): - async def test_order_by(self): - await Tournament.create(name="1") - await Tournament.create(name="2") - tournaments = await Tournament.all().order_by("name") - self.assertEqual([t.name for t in tournaments], ["1", "2"]) +@pytest.mark.asyncio +async def test_order_by(db): + await Tournament.create(name="1") + await Tournament.create(name="2") - async def test_order_by_reversed(self): - await Tournament.create(name="1") - await Tournament.create(name="2") + tournaments = await Tournament.all().order_by("name") + assert [t.name for t in 
tournaments] == ["1", "2"] - tournaments = await Tournament.all().order_by("-name") - self.assertEqual([t.name for t in tournaments], ["2", "1"]) - async def test_order_by_related(self): - tournament_first = await Tournament.create(name="1") - tournament_second = await Tournament.create(name="2") - await Event.create(name="b", tournament=tournament_first) - await Event.create(name="a", tournament=tournament_second) +@pytest.mark.asyncio +async def test_order_by_reversed(db): + await Tournament.create(name="1") + await Tournament.create(name="2") - tournaments = await Tournament.all().order_by("events__name") - self.assertEqual([t.name for t in tournaments], ["2", "1"]) + tournaments = await Tournament.all().order_by("-name") + assert [t.name for t in tournaments] == ["2", "1"] - async def test_order_by_ambigious_field_name(self): - tournament_first = await Tournament.create(name="Tournament 1", desc="d1") - tournament_second = await Tournament.create(name="Tournament 2", desc="d2") - event_third = await Event.create(name="3", tournament=tournament_second) - event_second = await Event.create(name="2", tournament=tournament_first) - event_first = await Event.create(name="1", tournament=tournament_first) +@pytest.mark.asyncio +async def test_order_by_related(db): + tournament_first = await Tournament.create(name="1") + tournament_second = await Tournament.create(name="2") + await Event.create(name="b", tournament=tournament_first) + await Event.create(name="a", tournament=tournament_second) - res = await Event.all().order_by("tournament__name", "name") - self.assertEqual(res, [event_first, event_second, event_third]) + tournaments = await Tournament.all().order_by("events__name") + assert [t.name for t in tournaments] == ["2", "1"] - async def test_order_by_related_reversed(self): - tournament_first = await Tournament.create(name="1") - tournament_second = await Tournament.create(name="2") - await Event.create(name="b", tournament=tournament_first) - await 
Event.create(name="a", tournament=tournament_second) - tournaments = await Tournament.all().order_by("-events__name") - self.assertEqual([t.name for t in tournaments], ["1", "2"]) +@pytest.mark.asyncio +async def test_order_by_ambigious_field_name(db): + tournament_first = await Tournament.create(name="Tournament 1", desc="d1") + tournament_second = await Tournament.create(name="Tournament 2", desc="d2") + + event_third = await Event.create(name="3", tournament=tournament_second) + event_second = await Event.create(name="2", tournament=tournament_first) + event_first = await Event.create(name="1", tournament=tournament_first) + + res = await Event.all().order_by("tournament__name", "name") + assert res == [event_first, event_second, event_third] - async def test_order_by_relation(self): - with self.assertRaises(FieldError): - tournament_first = await Tournament.create(name="1") - await Event.create(name="b", tournament=tournament_first) - await Tournament.all().order_by("events") +@pytest.mark.asyncio +async def test_order_by_related_reversed(db): + tournament_first = await Tournament.create(name="1") + tournament_second = await Tournament.create(name="2") + await Event.create(name="b", tournament=tournament_first) + await Event.create(name="a", tournament=tournament_second) - async def test_order_by_unknown_field(self): - with self.assertRaises(FieldError): - tournament_first = await Tournament.create(name="1") - await Event.create(name="b", tournament=tournament_first) + tournaments = await Tournament.all().order_by("-events__name") + assert [t.name for t in tournaments] == ["1", "2"] - await Tournament.all().order_by("something_else") - async def test_order_by_aggregation(self): +@pytest.mark.asyncio +async def test_order_by_relation(db): + with pytest.raises(FieldError): tournament_first = await Tournament.create(name="1") - tournament_second = await Tournament.create(name="2") await Event.create(name="b", tournament=tournament_first) - await 
Event.create(name="c", tournament=tournament_first) - await Event.create(name="a", tournament=tournament_second) - tournaments = await Tournament.annotate(events_count=Count("events")).order_by( - "events_count" - ) - self.assertEqual([t.name for t in tournaments], ["2", "1"]) + await Tournament.all().order_by("events") + - async def test_order_by_aggregation_reversed(self): +@pytest.mark.asyncio +async def test_order_by_unknown_field(db): + with pytest.raises(FieldError): tournament_first = await Tournament.create(name="1") - tournament_second = await Tournament.create(name="2") await Event.create(name="b", tournament=tournament_first) - await Event.create(name="c", tournament=tournament_first) - await Event.create(name="a", tournament=tournament_second) - tournaments = await Tournament.annotate(events_count=Count("events")).order_by( - "-events_count" - ) - self.assertEqual([t.name for t in tournaments], ["1", "2"]) - - async def test_order_by_reserved_word_annotation(self): - await Tournament.create(name="1") - await Tournament.create(name="2") - - reserved_words = ["order", "group", "limit", "offset", "where"] - - for word in reserved_words: - tournaments = await Tournament.annotate(**{word: Lower("name")}).order_by(word) - self.assertEqual([t.name for t in tournaments], ["1", "2"]) - - async def test_distinct_values_with_annotation(self): - await Tournament.create(name="3") - await Tournament.create(name="1") - await Tournament.create(name="2") - - tournaments = ( - await Tournament.annotate( - name_orderable=Case( - When(Q(name="1"), then="1"), - When(Q(name="2"), then="2"), - When(Q(name="3"), then="3"), - default="-1", - ), - ) - .distinct() - .order_by("name_orderable", "-created") - .values("name", "name_orderable", "created") + await Tournament.all().order_by("something_else") + + +@pytest.mark.asyncio +async def test_order_by_aggregation(db): + tournament_first = await Tournament.create(name="1") + tournament_second = await Tournament.create(name="2") + 
await Event.create(name="b", tournament=tournament_first) + await Event.create(name="c", tournament=tournament_first) + await Event.create(name="a", tournament=tournament_second) + + tournaments = await Tournament.annotate(events_count=Count("events")).order_by("events_count") + assert [t.name for t in tournaments] == ["2", "1"] + + +@pytest.mark.asyncio +async def test_order_by_aggregation_reversed(db): + tournament_first = await Tournament.create(name="1") + tournament_second = await Tournament.create(name="2") + await Event.create(name="b", tournament=tournament_first) + await Event.create(name="c", tournament=tournament_first) + await Event.create(name="a", tournament=tournament_second) + + tournaments = await Tournament.annotate(events_count=Count("events")).order_by("-events_count") + assert [t.name for t in tournaments] == ["1", "2"] + + +@pytest.mark.asyncio +async def test_order_by_reserved_word_annotation(db): + await Tournament.create(name="1") + await Tournament.create(name="2") + + reserved_words = ["order", "group", "limit", "offset", "where"] + + for word in reserved_words: + tournaments = await Tournament.annotate(**{word: Lower("name")}).order_by(word) + assert [t.name for t in tournaments] == ["1", "2"] + + +@pytest.mark.asyncio +async def test_distinct_values_with_annotation(db): + await Tournament.create(name="3") + await Tournament.create(name="1") + await Tournament.create(name="2") + + tournaments = ( + await Tournament.annotate( + name_orderable=Case( + When(Q(name="1"), then="1"), + When(Q(name="2"), then="2"), + When(Q(name="3"), then="3"), + default="-1", + ), ) - self.assertEqual([t["name"] for t in tournaments], ["1", "2", "3"]) - - async def test_distinct_all_with_annotation(self): - await Tournament.create(name="3") - await Tournament.create(name="1") - await Tournament.create(name="2") - - tournaments = ( - await Tournament.annotate( - name_orderable=Case( - When(Q(name="1"), then="1"), - When(Q(name="2"), then="2"), - 
When(Q(name="3"), then="3"), - default="-1", - ), - ) - .distinct() - .order_by("name_orderable", "-created") + .distinct() + .order_by("name_orderable", "-created") + .values("name", "name_orderable", "created") + ) + assert [t["name"] for t in tournaments] == ["1", "2", "3"] + + +@pytest.mark.asyncio +async def test_distinct_all_with_annotation(db): + await Tournament.create(name="3") + await Tournament.create(name="1") + await Tournament.create(name="2") + + tournaments = ( + await Tournament.annotate( + name_orderable=Case( + When(Q(name="1"), then="1"), + When(Q(name="2"), then="2"), + When(Q(name="3"), then="3"), + default="-1", + ), ) - self.assertEqual([t.name for t in tournaments], ["1", "2", "3"]) + .distinct() + .order_by("name_orderable", "-created") + ) + assert [t.name for t in tournaments] == ["1", "2", "3"] + + +# ============================================================================ +# TestDefaultOrdering tests +# ============================================================================ + + +@pytest.mark.asyncio +@test.requireCapability(dialect=NotEQ("oracle")) +async def test_default_order(db): + await DefaultOrdered.create(one="2", second=1) + await DefaultOrdered.create(one="1", second=1) + + instance_list = await DefaultOrdered.all() + assert [i.one for i in instance_list] == ["1", "2"] -class TestDefaultOrdering(test.TestCase): - @test.requireCapability(dialect=NotEQ("oracle")) - async def test_default_order(self): - await DefaultOrdered.create(one="2", second=1) - await DefaultOrdered.create(one="1", second=1) +@pytest.mark.asyncio +@test.requireCapability(dialect=NotEQ("oracle")) +async def test_default_order_desc(db): + await DefaultOrderedDesc.create(one="1", second=1) + await DefaultOrderedDesc.create(one="2", second=1) - instance_list = await DefaultOrdered.all() - self.assertEqual([i.one for i in instance_list], ["1", "2"]) + instance_list = await DefaultOrderedDesc.all() + assert [i.one for i in instance_list] == ["2", "1"] - 
@test.requireCapability(dialect=NotEQ("oracle")) - async def test_default_order_desc(self): - await DefaultOrderedDesc.create(one="1", second=1) - await DefaultOrderedDesc.create(one="2", second=1) - instance_list = await DefaultOrderedDesc.all() - self.assertEqual([i.one for i in instance_list], ["2", "1"]) +@pytest.mark.asyncio +async def test_default_order_invalid(db): + await DefaultOrderedInvalid.create(one="1", second=1) + await DefaultOrderedInvalid.create(one="2", second=1) - async def test_default_order_invalid(self): - await DefaultOrderedInvalid.create(one="1", second=1) - await DefaultOrderedInvalid.create(one="2", second=1) + with pytest.raises(ConfigurationError): + await DefaultOrderedInvalid.all() - with self.assertRaises(ConfigurationError): - await DefaultOrderedInvalid.all() - async def test_default_order_annotated_query(self): - instance = await DefaultOrdered.create(one="2", second=1) - await FKToDefaultOrdered.create(link=instance, value=10) - await DefaultOrdered.create(one="1", second=1) +@pytest.mark.asyncio +async def test_default_order_annotated_query(db): + instance = await DefaultOrdered.create(one="2", second=1) + await FKToDefaultOrdered.create(link=instance, value=10) + await DefaultOrdered.create(one="1", second=1) - queryset = DefaultOrdered.all().annotate(res=Sum("related__value")) - queryset._make_query() - query = queryset.query.get_sql() - self.assertTrue("order by" not in query.lower()) + queryset = DefaultOrdered.all().annotate(res=Sum("related__value")) + queryset._make_query() + query = queryset.query.get_sql() + assert "order by" not in query.lower() diff --git a/tests/test_order_by_nested.py b/tests/test_order_by_nested.py index 7ba085e3a..c3b380c01 100644 --- a/tests/test_order_by_nested.py +++ b/tests/test_order_by_nested.py @@ -1,32 +1,33 @@ +import pytest + from tests.testmodels import Event, Tournament from tortoise.contrib import test from tortoise.contrib.test.condition import NotEQ -class 
TestOrderByNested(test.TestCase): - @test.requireCapability(dialect=NotEQ("oracle")) - async def test_basic(self): - await Event.create( - name="Event 1", tournament=await Tournament.create(name="Tournament 1", desc="B") - ) - await Event.create( - name="Event 2", tournament=await Tournament.create(name="Tournament 2", desc="A") - ) +@test.requireCapability(dialect=NotEQ("oracle")) +@pytest.mark.asyncio +async def test_order_by_nested_basic(db): + await Event.create( + name="Event 1", tournament=await Tournament.create(name="Tournament 1", desc="B") + ) + await Event.create( + name="Event 2", tournament=await Tournament.create(name="Tournament 2", desc="A") + ) - self.assertEqual( - await Event.all().order_by("-name").values("name"), - [{"name": "Event 2"}, {"name": "Event 1"}], - ) + assert await Event.all().order_by("-name").values("name") == [ + {"name": "Event 2"}, + {"name": "Event 1"}, + ] - self.assertEqual( - await Event.all().prefetch_related("tournament").values("tournament__desc"), - [{"tournament__desc": "B"}, {"tournament__desc": "A"}], - ) + assert await Event.all().prefetch_related("tournament").values("tournament__desc") == [ + {"tournament__desc": "B"}, + {"tournament__desc": "A"}, + ] - self.assertEqual( - await Event.all() - .prefetch_related("tournament") - .order_by("tournament__desc") - .values("tournament__desc"), - [{"tournament__desc": "A"}, {"tournament__desc": "B"}], - ) + assert ( + await Event.all() + .prefetch_related("tournament") + .order_by("tournament__desc") + .values("tournament__desc") + ) == [{"tournament__desc": "A"}, {"tournament__desc": "B"}] diff --git a/tests/test_posix_regex_filter.py b/tests/test_posix_regex_filter.py index 836d1cdfa..46c82b3a9 100644 --- a/tests/test_posix_regex_filter.py +++ b/tests/test_posix_regex_filter.py @@ -1,73 +1,67 @@ +import pytest + from tests import testmodels from tortoise.contrib import test -class RegexTestCase(test.TestCase): - async def asyncSetUp(self) -> None: - await 
super().asyncSetUp() +@test.requireCapability(support_for_posix_regex_queries=True) +@pytest.mark.asyncio +async def test_regex_filter(db): + author = await testmodels.Author.create(name="Johann Wolfgang von Goethe") + assert set( + await testmodels.Author.filter( + name__posix_regex="^Johann [a-zA-Z]+ von Goethe$" + ).values_list("name", flat=True) + ) == {author.name} -class TestPosixRegexFilter(test.TestCase): - @test.requireCapability(support_for_posix_regex_queries=True) - async def test_regex_filter(self): - author = await testmodels.Author.create(name="Johann Wolfgang von Goethe") - self.assertEqual( - set( - await testmodels.Author.filter( - name__posix_regex="^Johann [a-zA-Z]+ von Goethe$" - ).values_list("name", flat=True) - ), - {author.name}, +@test.requireCapability(dialect="postgres", support_for_posix_regex_queries=True) +@pytest.mark.asyncio +async def test_regex_filter_works_with_null_field_postgres(db): + await testmodels.Tournament.create(name="Test") + print(testmodels.Tournament.filter(desc__posix_regex="^test$").sql()) + assert ( + set( + await testmodels.Tournament.filter(desc__posix_regex="^test$").values_list( + "name", flat=True + ) ) + == set() + ) - @test.requireCapability(dialect="postgres", support_for_posix_regex_queries=True) - async def test_regex_filter_works_with_null_field_postgres(self): - await testmodels.Tournament.create(name="Test") - print(testmodels.Tournament.filter(desc__posix_regex="^test$").sql()) - self.assertEqual( - set( - await testmodels.Tournament.filter(desc__posix_regex="^test$").values_list( - "name", flat=True - ) - ), - set(), - ) - @test.requireCapability(dialect="sqlite", support_for_posix_regex_queries=True) - async def test_regex_filter_works_with_null_field_sqlite(self): - await testmodels.Tournament.create(name="Test") - print(testmodels.Tournament.filter(desc__posix_regex="^test$").sql()) - self.assertEqual( - set( - await testmodels.Tournament.filter(desc__posix_regex="^test$").values_list( - "name", 
flat=True - ) - ), - set(), +@test.requireCapability(dialect="sqlite", support_for_posix_regex_queries=True) +@pytest.mark.asyncio +async def test_regex_filter_works_with_null_field_sqlite(db): + await testmodels.Tournament.create(name="Test") + print(testmodels.Tournament.filter(desc__posix_regex="^test$").sql()) + assert ( + set( + await testmodels.Tournament.filter(desc__posix_regex="^test$").values_list( + "name", flat=True + ) ) + == set() + ) -class TestCaseInsensitivePosixRegexFilter(test.TestCase): - @test.requireCapability(dialect="postgres", support_for_posix_regex_queries=True) - async def test_case_insensitive_regex_filter_postgres(self): - author = await testmodels.Author.create(name="Johann Wolfgang von Goethe") - self.assertEqual( - set( - await testmodels.Author.filter( - name__iposix_regex="^johann [a-zA-Z]+ Von goethe$" - ).values_list("name", flat=True) - ), - {author.name}, - ) +@test.requireCapability(dialect="postgres", support_for_posix_regex_queries=True) +@pytest.mark.asyncio +async def test_case_insensitive_regex_filter_postgres(db): + author = await testmodels.Author.create(name="Johann Wolfgang von Goethe") + assert set( + await testmodels.Author.filter( + name__iposix_regex="^johann [a-zA-Z]+ Von goethe$" + ).values_list("name", flat=True) + ) == {author.name} - @test.requireCapability(dialect="sqlite", support_for_posix_regex_queries=True) - async def test_case_insensitive_regex_filter_sqlite(self): - author = await testmodels.Author.create(name="Johann Wolfgang von Goethe") - self.assertEqual( - set( - await testmodels.Author.filter( - name__iposix_regex="^johann [a-zA-Z]+ Von goethe$" - ).values_list("name", flat=True) - ), - {author.name}, - ) + +@test.requireCapability(dialect="sqlite", support_for_posix_regex_queries=True) +@pytest.mark.asyncio +async def test_case_insensitive_regex_filter_sqlite(db): + author = await testmodels.Author.create(name="Johann Wolfgang von Goethe") + assert set( + await testmodels.Author.filter( + 
name__iposix_regex="^johann [a-zA-Z]+ Von goethe$" + ).values_list("name", flat=True) + ) == {author.name} diff --git a/tests/test_prefetching.py b/tests/test_prefetching.py index 12b366656..face651cc 100644 --- a/tests/test_prefetching.py +++ b/tests/test_prefetching.py @@ -1,154 +1,175 @@ +import pytest + from tests.testmodels import Address, Event, Team, Tournament -from tortoise.contrib import test from tortoise.exceptions import FieldError, OperationalError from tortoise.functions import Count from tortoise.query_utils import Prefetch -class TestPrefetching(test.TestCase): - async def test_prefetch(self): - tournament = await Tournament.create(name="tournament") - event = await Event.create(name="First", tournament=tournament) - await Event.create(name="Second", tournament=tournament) - team = await Team.create(name="1") - team_second = await Team.create(name="2") - await event.participants.add(team, team_second) - tournament = await Tournament.all().prefetch_related("events__participants").first() - self.assertEqual(len(tournament.events[0].participants), 2) - self.assertEqual(len(tournament.events[1].participants), 0) - - async def test_prefetch_object(self): +@pytest.mark.asyncio +async def test_prefetch(db): + tournament = await Tournament.create(name="tournament") + event = await Event.create(name="First", tournament=tournament) + await Event.create(name="Second", tournament=tournament) + team = await Team.create(name="1") + team_second = await Team.create(name="2") + await event.participants.add(team, team_second) + tournament = await Tournament.all().prefetch_related("events__participants").first() + assert len(tournament.events[0].participants) == 2 + assert len(tournament.events[1].participants) == 0 + + +@pytest.mark.asyncio +async def test_prefetch_object(db): + tournament = await Tournament.create(name="tournament") + await Event.create(name="First", tournament=tournament) + await Event.create(name="Second", tournament=tournament) + 
tournament_with_filtered = ( + await Tournament.all() + .prefetch_related(Prefetch("events", queryset=Event.filter(name="First"))) + .first() + ) + tournament = await Tournament.first().prefetch_related("events") + assert len(tournament_with_filtered.events) == 1 + assert len(tournament.events) == 2 + + +@pytest.mark.asyncio +async def test_prefetch_unknown_field(db): + with pytest.raises(OperationalError): tournament = await Tournament.create(name="tournament") await Event.create(name="First", tournament=tournament) await Event.create(name="Second", tournament=tournament) - tournament_with_filtered = ( - await Tournament.all() - .prefetch_related(Prefetch("events", queryset=Event.filter(name="First"))) + await ( + Tournament.all() + .prefetch_related(Prefetch("events1", queryset=Event.filter(name="First"))) .first() ) - tournament = await Tournament.first().prefetch_related("events") - self.assertEqual(len(tournament_with_filtered.events), 1) - self.assertEqual(len(tournament.events), 2) - - async def test_prefetch_unknown_field(self): - with self.assertRaises(OperationalError): - tournament = await Tournament.create(name="tournament") - await Event.create(name="First", tournament=tournament) - await Event.create(name="Second", tournament=tournament) - await ( - Tournament.all() - .prefetch_related(Prefetch("events1", queryset=Event.filter(name="First"))) - .first() - ) - - async def test_prefetch_m2m(self): - tournament = await Tournament.create(name="tournament") - event = await Event.create(name="First", tournament=tournament) - team = await Team.create(name="1") - team_second = await Team.create(name="2") - await event.participants.add(team, team_second) - fetched_events = ( - await Event.all() - .prefetch_related(Prefetch("participants", queryset=Team.filter(name="1"))) - .first() - ) - self.assertEqual(len(fetched_events.participants), 1) - - async def test_prefetch_o2o(self): - tournament = await Tournament.create(name="tournament") - event = await 
Event.create(name="First", tournament=tournament) - await Address.create(city="Santa Monica", street="Ocean", event=event) - - fetched_events = await Event.all().prefetch_related("address").first() - self.assertEqual(fetched_events.address.city, "Santa Monica") - async def test_prefetch_nested(self): - tournament = await Tournament.create(name="tournament") - event = await Event.create(name="First", tournament=tournament) - await Event.create(name="Second", tournament=tournament) - team = await Team.create(name="1") - team_second = await Team.create(name="2") - await event.participants.add(team, team_second) - fetched_tournaments = ( - await Tournament.all() - .prefetch_related( - Prefetch("events", queryset=Event.filter(name="First")), - Prefetch("events__participants", queryset=Team.filter(name="1")), - ) - .first() +@pytest.mark.asyncio +async def test_prefetch_m2m(db): + tournament = await Tournament.create(name="tournament") + event = await Event.create(name="First", tournament=tournament) + team = await Team.create(name="1") + team_second = await Team.create(name="2") + await event.participants.add(team, team_second) + fetched_events = ( + await Event.all() + .prefetch_related(Prefetch("participants", queryset=Team.filter(name="1"))) + .first() + ) + assert len(fetched_events.participants) == 1 + + +@pytest.mark.asyncio +async def test_prefetch_o2o(db): + tournament = await Tournament.create(name="tournament") + event = await Event.create(name="First", tournament=tournament) + await Address.create(city="Santa Monica", street="Ocean", event=event) + + fetched_events = await Event.all().prefetch_related("address").first() + + assert fetched_events.address.city == "Santa Monica" + + +@pytest.mark.asyncio +async def test_prefetch_nested(db): + tournament = await Tournament.create(name="tournament") + event = await Event.create(name="First", tournament=tournament) + await Event.create(name="Second", tournament=tournament) + team = await Team.create(name="1") + 
team_second = await Team.create(name="2") + await event.participants.add(team, team_second) + fetched_tournaments = ( + await Tournament.all() + .prefetch_related( + Prefetch("events", queryset=Event.filter(name="First")), + Prefetch("events__participants", queryset=Team.filter(name="1")), ) - self.assertEqual(len(fetched_tournaments.events[0].participants), 1) - - async def test_prefetch_nested_with_aggregation(self): - tournament = await Tournament.create(name="tournament") - event = await Event.create(name="First", tournament=tournament) - await Event.create(name="Second", tournament=tournament) - team = await Team.create(name="1") - team_second = await Team.create(name="2") - await event.participants.add(team, team_second) - fetched_tournaments = ( - await Tournament.all() - .prefetch_related( - Prefetch( - "events", queryset=Event.annotate(teams=Count("participants")).filter(teams=2) - ) - ) - .first() - ) - self.assertEqual(len(fetched_tournaments.events), 1) - self.assertEqual(fetched_tournaments.events[0].pk, event.pk) - - async def test_prefetch_direct_relation(self): - tournament = await Tournament.create(name="tournament") - await Event.create(name="First", tournament=tournament) - event = await Event.first().prefetch_related("tournament") - self.assertEqual(event.tournament.id, tournament.id) - - async def test_prefetch_bad_key(self): - tournament = await Tournament.create(name="tournament") - await Event.create(name="First", tournament=tournament) - with self.assertRaisesRegex(FieldError, "Relation tour1nament for models.Event not found"): - await Event.first().prefetch_related("tour1nament") - - async def test_prefetch_m2m_filter(self): - tournament = await Tournament.create(name="tournament") - team = await Team.create(name="1") - team_second = await Team.create(name="2") - event = await Event.create(name="First", tournament=tournament) - await event.participants.add(team, team_second) - event = await Event.first().prefetch_related( - 
Prefetch("participants", Team.filter(name="2")) - ) - self.assertEqual(len(event.participants), 1) - self.assertEqual(list(event.participants), [team_second]) - - async def test_prefetch_m2m_to_attr(self): - tournament = await Tournament.create(name="tournament") - team = await Team.create(name="1") - team_second = await Team.create(name="2") - event = await Event.create(name="First", tournament=tournament) - await event.participants.add(team, team_second) - event = await Event.first().prefetch_related( - Prefetch("participants", Team.filter(name="1"), to_attr="to_attr_participants_1"), - Prefetch("participants", Team.filter(name="2"), to_attr="to_attr_participants_2"), - ) - self.assertEqual(list(event.to_attr_participants_1), [team]) - self.assertEqual(list(event.to_attr_participants_2), [team_second]) - - async def test_prefetch_o2o_to_attr(self): - tournament = await Tournament.create(name="tournament") - event = await Event.create(name="First", tournament=tournament) - address = await Address.create(city="Santa Monica", street="Ocean", event=event) - event = await Event.get(pk=event.pk).prefetch_related( - Prefetch("address", to_attr="to_address", queryset=Address.all()) - ) - self.assertEqual(address.pk, event.to_address.pk) - - async def test_prefetch_direct_relation_to_attr(self): - tournament = await Tournament.create(name="tournament") - await Event.create(name="First", tournament=tournament) - event = await Event.first().prefetch_related( - Prefetch("tournament", queryset=Tournament.all(), to_attr="to_attr_tournament") + .first() + ) + assert len(fetched_tournaments.events[0].participants) == 1 + + +@pytest.mark.asyncio +async def test_prefetch_nested_with_aggregation(db): + tournament = await Tournament.create(name="tournament") + event = await Event.create(name="First", tournament=tournament) + await Event.create(name="Second", tournament=tournament) + team = await Team.create(name="1") + team_second = await Team.create(name="2") + await 
event.participants.add(team, team_second) + fetched_tournaments = ( + await Tournament.all() + .prefetch_related( + Prefetch("events", queryset=Event.annotate(teams=Count("participants")).filter(teams=2)) ) - self.assertEqual(event.to_attr_tournament.id, tournament.id) + .first() + ) + assert len(fetched_tournaments.events) == 1 + assert fetched_tournaments.events[0].pk == event.pk + + +@pytest.mark.asyncio +async def test_prefetch_direct_relation(db): + tournament = await Tournament.create(name="tournament") + await Event.create(name="First", tournament=tournament) + event = await Event.first().prefetch_related("tournament") + assert event.tournament.id == tournament.id + + +@pytest.mark.asyncio +async def test_prefetch_bad_key(db): + tournament = await Tournament.create(name="tournament") + await Event.create(name="First", tournament=tournament) + with pytest.raises(FieldError, match="Relation tour1nament for models.Event not found"): + await Event.first().prefetch_related("tour1nament") + + +@pytest.mark.asyncio +async def test_prefetch_m2m_filter(db): + tournament = await Tournament.create(name="tournament") + team = await Team.create(name="1") + team_second = await Team.create(name="2") + event = await Event.create(name="First", tournament=tournament) + await event.participants.add(team, team_second) + event = await Event.first().prefetch_related(Prefetch("participants", Team.filter(name="2"))) + assert len(event.participants) == 1 + assert list(event.participants) == [team_second] + + +@pytest.mark.asyncio +async def test_prefetch_m2m_to_attr(db): + tournament = await Tournament.create(name="tournament") + team = await Team.create(name="1") + team_second = await Team.create(name="2") + event = await Event.create(name="First", tournament=tournament) + await event.participants.add(team, team_second) + event = await Event.first().prefetch_related( + Prefetch("participants", Team.filter(name="1"), to_attr="to_attr_participants_1"), + Prefetch("participants", 
Team.filter(name="2"), to_attr="to_attr_participants_2"), + ) + assert list(event.to_attr_participants_1) == [team] + assert list(event.to_attr_participants_2) == [team_second] + + +@pytest.mark.asyncio +async def test_prefetch_o2o_to_attr(db): + tournament = await Tournament.create(name="tournament") + event = await Event.create(name="First", tournament=tournament) + address = await Address.create(city="Santa Monica", street="Ocean", event=event) + event = await Event.get(pk=event.pk).prefetch_related( + Prefetch("address", to_attr="to_address", queryset=Address.all()) + ) + assert address.pk == event.to_address.pk + + +@pytest.mark.asyncio +async def test_prefetch_direct_relation_to_attr(db): + tournament = await Tournament.create(name="tournament") + await Event.create(name="First", tournament=tournament) + event = await Event.first().prefetch_related( + Prefetch("tournament", queryset=Tournament.all(), to_attr="to_attr_tournament") + ) + assert event.to_attr_tournament.id == tournament.id diff --git a/tests/test_primary_key.py b/tests/test_primary_key.py index dbe224db4..d4033f5cb 100644 --- a/tests/test_primary_key.py +++ b/tests/test_primary_key.py @@ -3,6 +3,8 @@ import uuid from typing import Any +import pytest + from tests.testmodels import ( CharFkRelatedModel, CharM2MRelatedModel, @@ -13,55 +15,59 @@ UUIDPkModel, ) from tortoise import fields -from tortoise.contrib import test from tortoise.exceptions import ConfigurationError -class TestQueryset(test.TestCase): - async def test_implicit_pk(self): +class TestQueryset: + @pytest.mark.asyncio + async def test_implicit_pk(self, db): instance = await ImplicitPkModel.create(value="test") - self.assertTrue(instance.id) - self.assertEqual(instance.pk, instance.id) + assert instance.id + assert instance.pk == instance.id - async def test_uuid_pk(self): + @pytest.mark.asyncio + async def test_uuid_pk(self, db): value = uuid.uuid4() await UUIDPkModel.create(id=value) instance2 = await UUIDPkModel.get(id=value) - 
self.assertEqual(instance2.id, value) - self.assertEqual(instance2.pk, value) + assert instance2.id == value + assert instance2.pk == value - async def test_uuid_pk_default(self): + @pytest.mark.asyncio + async def test_uuid_pk_default(self, db): instance1 = await UUIDPkModel.create() - self.assertIsInstance(instance1.id, uuid.UUID) - self.assertEqual(instance1.pk, instance1.pk) + assert isinstance(instance1.id, uuid.UUID) + assert instance1.pk == instance1.pk instance2 = await UUIDPkModel.get(id=instance1.id) - self.assertEqual(instance2.id, instance1.id) - self.assertEqual(instance2.pk, instance1.id) + assert instance2.id == instance1.id + assert instance2.pk == instance1.id - async def test_uuid_pk_fk(self): + @pytest.mark.asyncio + async def test_uuid_pk_fk(self, db): value = uuid.uuid4() instance = await UUIDPkModel.create(id=value) instance2 = await UUIDPkModel.create(id=uuid.uuid4()) await UUIDFkRelatedModel.create(model=instance2) related_instance = await UUIDFkRelatedModel.create(model=instance) - self.assertEqual(related_instance.model_id, value) + assert related_instance.model_id == value related_instance = await UUIDFkRelatedModel.filter(model=instance).first() - self.assertEqual(related_instance.model_id, value) + assert related_instance.model_id == value related_instance = await UUIDFkRelatedModel.filter(model_id=value).first() - self.assertEqual(related_instance.model_id, value) + assert related_instance.model_id == value await related_instance.fetch_related("model") - self.assertEqual(related_instance.model, instance) + assert related_instance.model == instance await instance.fetch_related("children") - self.assertEqual(instance.children[0], related_instance) + assert instance.children[0] == related_instance - async def test_uuid_m2m(self): + @pytest.mark.asyncio + async def test_uuid_m2m(self, db): value = uuid.uuid4() instance = await UUIDPkModel.create(id=value) instance2 = await UUIDPkModel.create(id=uuid.uuid4()) @@ -73,52 +79,55 @@ async def 
test_uuid_m2m(self): await related_instance2.models.add(instance, instance2) await instance.fetch_related("peers") - self.assertEqual(len(instance.peers), 2) - self.assertEqual(set(instance.peers), {related_instance, related_instance2}) + assert len(instance.peers) == 2 + assert set(instance.peers) == {related_instance, related_instance2} await related_instance.fetch_related("models") - self.assertEqual(len(related_instance.models), 1) - self.assertEqual(related_instance.models[0], instance) + assert len(related_instance.models) == 1 + assert related_instance.models[0] == instance await related_instance2.fetch_related("models") - self.assertEqual(len(related_instance2.models), 2) - self.assertEqual({m.pk for m in related_instance2.models}, {instance.pk, instance2.pk}) + assert len(related_instance2.models) == 2 + assert {m.pk for m in related_instance2.models} == {instance.pk, instance2.pk} related_instance_list = await UUIDM2MRelatedModel.filter(models=instance2) - self.assertEqual(len(related_instance_list), 1) - self.assertEqual(related_instance_list[0], related_instance2) + assert len(related_instance_list) == 1 + assert related_instance_list[0] == related_instance2 related_instance_list = await UUIDM2MRelatedModel.filter(models__in=[instance2]) - self.assertEqual(len(related_instance_list), 1) - self.assertEqual(related_instance_list[0], related_instance2) + assert len(related_instance_list) == 1 + assert related_instance_list[0] == related_instance2 - async def test_char_pk(self): + @pytest.mark.asyncio + async def test_char_pk(self, db): value = "Da-PK" await CharPkModel.create(id=value) instance2 = await CharPkModel.get(id=value) - self.assertEqual(instance2.id, value) - self.assertEqual(instance2.pk, value) + assert instance2.id == value + assert instance2.pk == value - async def test_char_pk_fk(self): + @pytest.mark.asyncio + async def test_char_pk_fk(self, db): value = "Da-PK-for-FK" instance = await CharPkModel.create(id=value) instance2 = await 
CharPkModel.create(id=uuid.uuid4()) await CharFkRelatedModel.create(model=instance2) related_instance = await CharFkRelatedModel.create(model=instance) - self.assertEqual(related_instance.model_id, value) + assert related_instance.model_id == value related_instance = await CharFkRelatedModel.filter(model=instance).first() - self.assertEqual(related_instance.model_id, value) + assert related_instance.model_id == value related_instance = await CharFkRelatedModel.filter(model_id=value).first() - self.assertEqual(related_instance.model_id, value) + assert related_instance.model_id == value await instance.fetch_related("children") - self.assertEqual(instance.children[0], related_instance) + assert instance.children[0] == related_instance - async def test_char_m2m(self): + @pytest.mark.asyncio + async def test_char_m2m(self, db): value = "Da-PK-for-M2M" instance = await CharPkModel.create(id=value) instance2 = await CharPkModel.create(id=uuid.uuid4()) @@ -130,111 +139,149 @@ async def test_char_m2m(self): await related_instance2.models.add(instance, instance2) await related_instance.fetch_related("models") - self.assertEqual(len(related_instance.models), 1) - self.assertEqual(related_instance.models[0], instance) + assert len(related_instance.models) == 1 + assert related_instance.models[0] == instance await related_instance2.fetch_related("models") - self.assertEqual(len(related_instance2.models), 2) - self.assertEqual({m.pk for m in related_instance2.models}, {instance.pk, instance2.pk}) + assert len(related_instance2.models) == 2 + assert {m.pk for m in related_instance2.models} == {instance.pk, instance2.pk} related_instance_list = await CharM2MRelatedModel.filter(models=instance2) - self.assertEqual(len(related_instance_list), 1) - self.assertEqual(related_instance_list[0], related_instance2) + assert len(related_instance_list) == 1 + assert related_instance_list[0] == related_instance2 related_instance_list = await CharM2MRelatedModel.filter(models__in=[instance2]) 
- self.assertEqual(len(related_instance_list), 1) - self.assertEqual(related_instance_list[0], related_instance2) + assert len(related_instance_list) == 1 + assert related_instance_list[0] == related_instance2 -class TestPkIndexAlias(test.TestCase): - Field: Any = fields.CharField - init_kwargs = {"max_length": 10} +# Test parameters for pk index alias tests +# Format: (Field class, init_kwargs, field_id) +PK_INDEX_ALIAS_PARAMS = [ + pytest.param(fields.CharField, {"max_length": 10}, id="CharField"), + pytest.param(fields.UUIDField, {}, id="UUIDField"), + pytest.param(fields.IntField, {}, id="IntField"), + pytest.param(fields.BigIntField, {}, id="BigIntField"), + pytest.param(fields.SmallIntField, {}, id="SmallIntField"), +] - async def test_pk_alias_warning(self): + +class TestPkIndexAlias: + """Test pk alias functionality for various field types.""" + + @pytest.mark.asyncio + @pytest.mark.parametrize("Field,init_kwargs", PK_INDEX_ALIAS_PARAMS) + async def test_pk_alias_warning(self, Field: Any, init_kwargs: dict): msg = "`pk` is deprecated, please use `primary_key` instead" - with self.assertWarnsRegex(DeprecationWarning, msg): - f = self.Field(pk=True, **self.init_kwargs) + with pytest.warns(DeprecationWarning, match=msg): + f = Field(pk=True, **init_kwargs) assert f.pk is True - with self.assertWarnsRegex(DeprecationWarning, msg): - f = self.Field(pk=False, **self.init_kwargs) + with pytest.warns(DeprecationWarning, match=msg): + f = Field(pk=False, **init_kwargs) assert f.pk is False - async def test_pk_alias_error(self): - with self.assertRaises(ConfigurationError): - self.Field(pk=True, primary_key=False, **self.init_kwargs) - with self.assertRaises(ConfigurationError): - self.Field(pk=False, primary_key=True, **self.init_kwargs) - - async def test_pk_alias_compare(self): + @pytest.mark.asyncio + @pytest.mark.parametrize("Field,init_kwargs", PK_INDEX_ALIAS_PARAMS) + async def test_pk_alias_error(self, Field: Any, init_kwargs: dict): + with 
pytest.raises(ConfigurationError): + Field(pk=True, primary_key=False, **init_kwargs) + with pytest.raises(ConfigurationError): + Field(pk=False, primary_key=True, **init_kwargs) + + @pytest.mark.asyncio + @pytest.mark.parametrize("Field,init_kwargs", PK_INDEX_ALIAS_PARAMS) + async def test_pk_alias_compare(self, Field: Any, init_kwargs: dict): # Only for compare, not recommended - f = self.Field(pk=True, primary_key=True, **self.init_kwargs) + f = Field(pk=True, primary_key=True, **init_kwargs) assert f.pk is True - f = self.Field(pk=False, primary_key=False, **self.init_kwargs) + f = Field(pk=False, primary_key=False, **init_kwargs) assert f.pk is False -class TestPkIndexAliasUUID(TestPkIndexAlias): - Field: Any = fields.UUIDField - init_kwargs = {} +class TestPkIndexAliasUUID: + """UUID-specific pk alias tests.""" + @pytest.mark.asyncio async def test_default(self): msg = "`pk` is deprecated, please use `primary_key` instead" - with self.assertWarnsRegex(DeprecationWarning, msg): - f = self.Field(pk=True) + with pytest.warns(DeprecationWarning, match=msg): + f = fields.UUIDField(pk=True) assert f.default == uuid.uuid4 - f = self.Field(primary_key=True) + f = fields.UUIDField(primary_key=True) assert f.default == uuid.uuid4 - f = self.Field() + f = fields.UUIDField() assert f.default is None - f = self.Field(default=1) + f = fields.UUIDField(default=1) assert f.default == 1 -class TestPkIndexAliasInt(TestPkIndexAlias): - Field: Any = fields.IntField - init_kwargs = {} +# Int field types that support positional pk argument +INT_FIELD_TYPES = [ + pytest.param(fields.IntField, id="IntField"), + pytest.param(fields.BigIntField, id="BigIntField"), + pytest.param(fields.SmallIntField, id="SmallIntField"), +] - async def test_argument(self): - f = self.Field(True) - assert f.pk is True - f = self.Field(False) - assert f.pk is False +class TestPkIndexAliasInt: + """Int field types support positional pk argument.""" -class TestPkIndexAliasBigInt(TestPkIndexAliasInt): - 
Field = fields.BigIntField - + @pytest.mark.asyncio + @pytest.mark.parametrize("Field", INT_FIELD_TYPES) + async def test_argument(self, Field: Any): + f = Field(True) + assert f.pk is True + f = Field(False) + assert f.pk is False -class TestPkIndexAliasSmallInt(TestPkIndexAliasInt): - Field = fields.SmallIntField +class TestPkIndexAliasText: + """TextField pk alias tests with deprecation warnings.""" -class TestPkIndexAliasText(TestPkIndexAlias): - Field = fields.TextField message = "TextField as a PrimaryKey is Deprecated, use CharField instead" + pk_deprecation_msg = "`pk` is deprecated, please use `primary_key` instead" def test_warning(self): - with self.assertWarnsRegex(DeprecationWarning, self.message): - f = self.Field(pk=True) + # TextField(pk=True) emits both warnings: pk deprecation and TextField as PK + with pytest.warns(DeprecationWarning, match=self.message): + f = fields.TextField(pk=True) assert f.pk is True - with self.assertWarnsRegex(DeprecationWarning, self.message): - f = self.Field(primary_key=True) + with pytest.warns(DeprecationWarning, match=self.message): + f = fields.TextField(primary_key=True) assert f.pk is True - with self.assertWarnsRegex(DeprecationWarning, self.message): - f = self.Field(True) + # Positional arg goes to primary_key, so only TextField as PK warning + with pytest.warns(DeprecationWarning, match=self.message): + f = fields.TextField(True) assert f.pk is True - async def test_pk_alias_error(self): - with self.assertRaises(ConfigurationError): - with self.assertWarnsRegex(DeprecationWarning, self.message): - self.Field(pk=True, primary_key=False, **self.init_kwargs) - with self.assertRaises(ConfigurationError): - with self.assertWarnsRegex(DeprecationWarning, self.message): - self.Field(pk=False, primary_key=True, **self.init_kwargs) + @pytest.mark.asyncio + async def test_pk_alias_warning(self): + # TextField(pk=True) emits TextField as PK warning (and pk deprecation) + with pytest.warns(DeprecationWarning, 
match=self.message): + f = fields.TextField(pk=True) + assert f.pk is True + # pk=False does not trigger TextField as PK warning, but triggers pk deprecation + with pytest.warns(DeprecationWarning, match=self.pk_deprecation_msg): + f = fields.TextField(pk=False) + assert f.pk is False + @pytest.mark.asyncio + async def test_pk_alias_error(self): + # Both pk=True and primary_key=False triggers TextField warning first, then raises + with pytest.raises(ConfigurationError): + with pytest.warns(DeprecationWarning, match=self.message): + fields.TextField(pk=True, primary_key=False) + # pk=False and primary_key=True triggers TextField as PK warning, then raises + with pytest.raises(ConfigurationError): + with pytest.warns(DeprecationWarning, match=self.message): + fields.TextField(pk=False, primary_key=True) + + @pytest.mark.asyncio async def test_pk_alias_compare(self): - with self.assertWarnsRegex(DeprecationWarning, self.message): - f = self.Field(pk=True, primary_key=True, **self.init_kwargs) + # Both pk=True and primary_key=True: TextField as PK warning is emitted + with pytest.warns(DeprecationWarning, match=self.message): + f = fields.TextField(pk=True, primary_key=True) assert f.pk is True - f = self.Field(pk=False, primary_key=False, **self.init_kwargs) + # pk=False and primary_key=False: no warnings + f = fields.TextField(pk=False, primary_key=False) assert f.pk is False diff --git a/tests/test_q.py b/tests/test_q.py index e25eb6907..28bad59c3 100644 --- a/tests/test_q.py +++ b/tests/test_q.py @@ -1,266 +1,302 @@ import operator -from unittest import TestCase as _TestCase +import pytest from pypika_tortoise.context import DEFAULT_SQL_CONTEXT from tests.testmodels import CharFields, IntFields -from tortoise.contrib.test import TestCase from tortoise.exceptions import OperationalError from tortoise.expressions import F, Q, ResolveContext +# ============================================================================= +# Tests for Q object basic operations (no 
database needed) +# ============================================================================= -class TestQ(_TestCase): - def test_q_basic(self): - q = Q(moo="cow") - self.assertEqual(q.children, ()) - self.assertEqual(q.filters, {"moo": "cow"}) - self.assertEqual(q.join_type, "AND") - - def test_q_compound(self): - q1 = Q(moo="cow") - q2 = Q(moo="bull") - q = Q(q1, q2, join_type=Q.OR) - - self.assertEqual(q1.children, ()) - self.assertEqual(q1.filters, {"moo": "cow"}) - self.assertEqual(q1.join_type, "AND") - - self.assertEqual(q2.children, ()) - self.assertEqual(q2.filters, {"moo": "bull"}) - self.assertEqual(q2.join_type, "AND") - - self.assertEqual(q.children, (q1, q2)) - self.assertEqual(q.filters, {}) - self.assertEqual(q.join_type, "OR") - - def test_q_compound_or(self): - q1 = Q(moo="cow") - q2 = Q(moo="bull") - q = q1 | q2 - - self.assertEqual(q1.children, ()) - self.assertEqual(q1.filters, {"moo": "cow"}) - self.assertEqual(q1.join_type, "AND") - - self.assertEqual(q2.children, ()) - self.assertEqual(q2.filters, {"moo": "bull"}) - self.assertEqual(q2.join_type, "AND") - - self.assertEqual(q.children, (q1, q2)) - self.assertEqual(q.filters, {}) - self.assertEqual(q.join_type, "OR") - - def test_q_compound_and(self): - q1 = Q(moo="cow") - q2 = Q(moo="bull") - q = q1 & q2 - - self.assertEqual(q1.children, ()) - self.assertEqual(q1.filters, {"moo": "cow"}) - self.assertEqual(q1.join_type, "AND") - - self.assertEqual(q2.children, ()) - self.assertEqual(q2.filters, {"moo": "bull"}) - self.assertEqual(q2.join_type, "AND") - - self.assertEqual(q.children, (q1, q2)) - self.assertEqual(q.filters, {}) - self.assertEqual(q.join_type, "AND") - - def test_q_compound_or_notq(self): - with self.assertRaisesRegex(OperationalError, "OR operation requires a Q node"): - Q() | 2 # pylint: disable=W0106 - - def test_q_compound_and_notq(self): - with self.assertRaisesRegex(OperationalError, "AND operation requires a Q node"): - Q() & 2 # pylint: disable=W0106 - - def 
test_q_notq(self): - with self.assertRaisesRegex(OperationalError, "All ordered arguments must be Q nodes"): - Q(Q(), 1) - - def test_q_bad_join_type(self): - with self.assertRaisesRegex(OperationalError, "join_type must be AND or OR"): - Q(join_type=3) - - def test_q_partial_and(self): - q = Q(join_type="AND", moo="cow") - self.assertEqual(q.children, ()) - self.assertEqual(q.filters, {"moo": "cow"}) - self.assertEqual(q.join_type, "AND") - - def test_q_partial_or(self): - q = Q(join_type="OR", moo="cow") - self.assertEqual(q.children, ()) - self.assertEqual(q.filters, {"moo": "cow"}) - self.assertEqual(q.join_type, "OR") - - def test_q_equality(self): - # basic query - basic_q1 = Q(moo="cow") - basic_q2 = Q(moo="cow") - self.assertEqual(basic_q1, basic_q2) - - # and query - and_q1 = Q(firstname="John") & Q(lastname="Doe") - and_q2 = Q(firstname="John") & Q(lastname="Doe") - self.assertEqual(and_q1, and_q2) - - # or query - or_q1 = Q(firstname="John") | Q(lastname="Doe") - or_q2 = Q(firstname="John") | Q(lastname="Doe") - self.assertEqual(or_q1, or_q2) - - # complex query - complex_q1 = (Q(firstname="John") & Q(lastname="Doe")) | Q(mother_name="Jane") - complex_q2 = (Q(firstname="John") & Q(lastname="Doe")) | Q(mother_name="Jane") - self.assertEqual(complex_q1, complex_q2) - - -class TestQCall(TestCase): - def setUp(self) -> None: - super().setUp() - self.int_fields_context = ResolveContext( - model=IntFields, - table=IntFields._meta.basequery, # type:ignore[arg-type] - annotations={}, - custom_filters={}, - ) - self.char_fields_context = ResolveContext( - model=CharFields, - table=CharFields._meta.basequery, # type:ignore[arg-type] - annotations={}, - custom_filters={}, - ) - def test_q_basic(self): - q = Q(id=8) - r = q.resolve(self.int_fields_context) - self.assertEqual(r.where_criterion.get_sql(DEFAULT_SQL_CONTEXT), '"id"=8') - - def test_q_basic_and(self): - q = Q(join_type="AND", id=8) - r = q.resolve(self.int_fields_context) - 
self.assertEqual(r.where_criterion.get_sql(DEFAULT_SQL_CONTEXT), '"id"=8') - - def test_q_basic_or(self): - q = Q(join_type="OR", id=8) - r = q.resolve(self.int_fields_context) - self.assertEqual(r.where_criterion.get_sql(DEFAULT_SQL_CONTEXT), '"id"=8') - - def test_q_multiple_and(self): - q = Q(join_type="AND", id__gt=8, id__lt=10) - r = q.resolve(self.int_fields_context) - self.assertEqual(r.where_criterion.get_sql(DEFAULT_SQL_CONTEXT), '"id">8 AND "id"<10') - - def test_q_multiple_or(self): - q = Q(join_type="OR", id__gt=8, id__lt=10) - r = q.resolve(self.int_fields_context) - self.assertEqual(r.where_criterion.get_sql(DEFAULT_SQL_CONTEXT), '"id">8 OR "id"<10') - - def test_q_multiple_and2(self): - q = Q(join_type="AND", id=8, intnum=80) - r = q.resolve(self.int_fields_context) - self.assertEqual(r.where_criterion.get_sql(DEFAULT_SQL_CONTEXT), '"id"=8 AND "intnum"=80') - - def test_q_multiple_or2(self): - q = Q(join_type="OR", id=8, intnum=80) - r = q.resolve(self.int_fields_context) - self.assertEqual(r.where_criterion.get_sql(DEFAULT_SQL_CONTEXT), '"id"=8 OR "intnum"=80') - - def test_q_complex_int(self): - q = Q(Q(intnum=80), Q(id__lt=5, id__gt=50, join_type="OR"), join_type="AND") - r = q.resolve(self.int_fields_context) - self.assertEqual( - r.where_criterion.get_sql(DEFAULT_SQL_CONTEXT), '"intnum"=80 AND ("id"<5 OR "id">50)' - ) +def test_q_basic(): + q = Q(moo="cow") + assert q.children == () + assert q.filters == {"moo": "cow"} + assert q.join_type == "AND" - def test_q_complex_int2(self): - q = Q(Q(intnum="80"), Q(Q(id__lt="5"), Q(id__gt="50"), join_type="OR"), join_type="AND") - r = q.resolve(self.int_fields_context) - self.assertEqual( - r.where_criterion.get_sql(DEFAULT_SQL_CONTEXT), '"intnum"=80 AND ("id"<5 OR "id">50)' - ) - def test_q_complex_int3(self): - q = Q(Q(id__lt=5, id__gt=50, join_type="OR"), join_type="AND", intnum=80) - r = q.resolve(self.int_fields_context) - self.assertEqual( - r.where_criterion.get_sql(DEFAULT_SQL_CONTEXT), 
'"intnum"=80 AND ("id"<5 OR "id">50)' - ) +def test_q_compound(): + q1 = Q(moo="cow") + q2 = Q(moo="bull") + q = Q(q1, q2, join_type=Q.OR) - def test_q_complex_char(self): - q = Q(Q(char_null=80), ~Q(char__lt=5, char__gt=50, join_type="OR"), join_type="AND") - r = q.resolve(self.char_fields_context) - self.assertEqual( - r.where_criterion.get_sql(DEFAULT_SQL_CONTEXT), - "\"char_null\"='80' AND NOT (\"char\"<'5' OR \"char\">'50')", - ) + assert q1.children == () + assert q1.filters == {"moo": "cow"} + assert q1.join_type == "AND" - def test_q_complex_char2(self): - q = Q( - Q(char_null="80"), - ~Q(Q(char__lt="5"), Q(char__gt="50"), join_type="OR"), - join_type="AND", - ) - r = q.resolve(self.char_fields_context) - self.assertEqual( - r.where_criterion.get_sql(DEFAULT_SQL_CONTEXT), - "\"char_null\"='80' AND NOT (\"char\"<'5' OR \"char\">'50')", - ) + assert q2.children == () + assert q2.filters == {"moo": "bull"} + assert q2.join_type == "AND" + + assert q.children == (q1, q2) + assert q.filters == {} + assert q.join_type == "OR" + + +def test_q_compound_or(): + q1 = Q(moo="cow") + q2 = Q(moo="bull") + q = q1 | q2 + + assert q1.children == () + assert q1.filters == {"moo": "cow"} + assert q1.join_type == "AND" + + assert q2.children == () + assert q2.filters == {"moo": "bull"} + assert q2.join_type == "AND" + + assert q.children == (q1, q2) + assert q.filters == {} + assert q.join_type == "OR" + + +def test_q_compound_and(): + q1 = Q(moo="cow") + q2 = Q(moo="bull") + q = q1 & q2 + + assert q1.children == () + assert q1.filters == {"moo": "cow"} + assert q1.join_type == "AND" + + assert q2.children == () + assert q2.filters == {"moo": "bull"} + assert q2.join_type == "AND" + + assert q.children == (q1, q2) + assert q.filters == {} + assert q.join_type == "AND" + + +def test_q_compound_or_notq(): + with pytest.raises(OperationalError, match="OR operation requires a Q node"): + Q() | 2 # pylint: disable=W0106 + + +def test_q_compound_and_notq(): + with 
pytest.raises(OperationalError, match="AND operation requires a Q node"): + Q() & 2 # pylint: disable=W0106 + + +def test_q_notq(): + with pytest.raises(OperationalError, match="All ordered arguments must be Q nodes"): + Q(Q(), 1) + + +def test_q_bad_join_type(): + with pytest.raises(OperationalError, match="join_type must be AND or OR"): + Q(join_type=3) + + +def test_q_partial_and(): + q = Q(join_type="AND", moo="cow") + assert q.children == () + assert q.filters == {"moo": "cow"} + assert q.join_type == "AND" + + +def test_q_partial_or(): + q = Q(join_type="OR", moo="cow") + assert q.children == () + assert q.filters == {"moo": "cow"} + assert q.join_type == "OR" + + +def test_q_equality(): + # basic query + basic_q1 = Q(moo="cow") + basic_q2 = Q(moo="cow") + assert basic_q1 == basic_q2 + + # and query + and_q1 = Q(firstname="John") & Q(lastname="Doe") + and_q2 = Q(firstname="John") & Q(lastname="Doe") + assert and_q1 == and_q2 + + # or query + or_q1 = Q(firstname="John") | Q(lastname="Doe") + or_q2 = Q(firstname="John") | Q(lastname="Doe") + assert or_q1 == or_q2 + + # complex query + complex_q1 = (Q(firstname="John") & Q(lastname="Doe")) | Q(mother_name="Jane") + complex_q2 = (Q(firstname="John") & Q(lastname="Doe")) | Q(mother_name="Jane") + assert complex_q1 == complex_q2 + + +# ============================================================================= +# Tests for Q object resolution (requires database for model resolution) +# ============================================================================= + + +@pytest.fixture +def int_fields_context(db): + """Context for IntFields model resolution.""" + return ResolveContext( + model=IntFields, + table=IntFields._meta.basequery, + annotations={}, + custom_filters={}, + ) - def test_q_complex_char3(self): - q = Q(~Q(char__lt=5, char__gt=50, join_type="OR"), join_type="AND", char_null=80) - r = q.resolve(self.char_fields_context) - self.assertEqual( - r.where_criterion.get_sql(DEFAULT_SQL_CONTEXT), - 
"\"char_null\"='80' AND NOT (\"char\"<'5' OR \"char\">'50')", - ) - def test_q_with_blank_and(self): - q = Q(Q(id__gt=5), Q(), join_type=Q.AND) - r = q.resolve(self.char_fields_context) - self.assertEqual(r.where_criterion.get_sql(DEFAULT_SQL_CONTEXT), '"id">5') - - def test_q_with_blank_or(self): - q = Q(Q(id__gt=5), Q(), join_type=Q.OR) - r = q.resolve(self.char_fields_context) - self.assertEqual(r.where_criterion.get_sql(DEFAULT_SQL_CONTEXT), '"id">5') - - def test_q_with_blank_and2(self): - q = Q(id__gt=5) & Q() - r = q.resolve(self.char_fields_context) - self.assertEqual(r.where_criterion.get_sql(DEFAULT_SQL_CONTEXT), '"id">5') - - def test_q_with_blank_or2(self): - q = Q(id__gt=5) | Q() - r = q.resolve(self.char_fields_context) - self.assertEqual(r.where_criterion.get_sql(DEFAULT_SQL_CONTEXT), '"id">5') - - def test_q_with_blank_and3(self): - q = Q() & Q(id__gt=5) - r = q.resolve(self.char_fields_context) - self.assertEqual(r.where_criterion.get_sql(DEFAULT_SQL_CONTEXT), '"id">5') - - def test_q_with_blank_or3(self): - q = Q() | Q(id__gt=5) - r = q.resolve(self.char_fields_context) - self.assertEqual(r.where_criterion.get_sql(DEFAULT_SQL_CONTEXT), '"id">5') - - def test_annotations_resolved(self): - q = Q(id__gt=5) | Q(annotated__lt=5) - r = q.resolve( - ResolveContext( - model=IntFields, - table=IntFields._meta.basequery, - annotations={"annotated": F("intnum")}, - custom_filters={ - "annotated__lt": { - "field": "annotated", - "source_field": "annotated", - "operator": operator.lt, - } - }, - ) +@pytest.fixture +def char_fields_context(db): + """Context for CharFields model resolution.""" + return ResolveContext( + model=CharFields, + table=CharFields._meta.basequery, + annotations={}, + custom_filters={}, + ) + + +def test_q_call_basic(int_fields_context): + q = Q(id=8) + r = q.resolve(int_fields_context) + assert r.where_criterion.get_sql(DEFAULT_SQL_CONTEXT) == '"id"=8' + + +def test_q_call_basic_and(int_fields_context): + q = Q(join_type="AND", id=8) + 
r = q.resolve(int_fields_context) + assert r.where_criterion.get_sql(DEFAULT_SQL_CONTEXT) == '"id"=8' + + +def test_q_call_basic_or(int_fields_context): + q = Q(join_type="OR", id=8) + r = q.resolve(int_fields_context) + assert r.where_criterion.get_sql(DEFAULT_SQL_CONTEXT) == '"id"=8' + + +def test_q_call_multiple_and(int_fields_context): + q = Q(join_type="AND", id__gt=8, id__lt=10) + r = q.resolve(int_fields_context) + assert r.where_criterion.get_sql(DEFAULT_SQL_CONTEXT) == '"id">8 AND "id"<10' + + +def test_q_call_multiple_or(int_fields_context): + q = Q(join_type="OR", id__gt=8, id__lt=10) + r = q.resolve(int_fields_context) + assert r.where_criterion.get_sql(DEFAULT_SQL_CONTEXT) == '"id">8 OR "id"<10' + + +def test_q_call_multiple_and2(int_fields_context): + q = Q(join_type="AND", id=8, intnum=80) + r = q.resolve(int_fields_context) + assert r.where_criterion.get_sql(DEFAULT_SQL_CONTEXT) == '"id"=8 AND "intnum"=80' + + +def test_q_call_multiple_or2(int_fields_context): + q = Q(join_type="OR", id=8, intnum=80) + r = q.resolve(int_fields_context) + assert r.where_criterion.get_sql(DEFAULT_SQL_CONTEXT) == '"id"=8 OR "intnum"=80' + + +def test_q_call_complex_int(int_fields_context): + q = Q(Q(intnum=80), Q(id__lt=5, id__gt=50, join_type="OR"), join_type="AND") + r = q.resolve(int_fields_context) + assert r.where_criterion.get_sql(DEFAULT_SQL_CONTEXT) == '"intnum"=80 AND ("id"<5 OR "id">50)' + + +def test_q_call_complex_int2(int_fields_context): + q = Q(Q(intnum="80"), Q(Q(id__lt="5"), Q(id__gt="50"), join_type="OR"), join_type="AND") + r = q.resolve(int_fields_context) + assert r.where_criterion.get_sql(DEFAULT_SQL_CONTEXT) == '"intnum"=80 AND ("id"<5 OR "id">50)' + + +def test_q_call_complex_int3(int_fields_context): + q = Q(Q(id__lt=5, id__gt=50, join_type="OR"), join_type="AND", intnum=80) + r = q.resolve(int_fields_context) + assert r.where_criterion.get_sql(DEFAULT_SQL_CONTEXT) == '"intnum"=80 AND ("id"<5 OR "id">50)' + + +def 
test_q_call_complex_char(char_fields_context): + q = Q(Q(char_null=80), ~Q(char__lt=5, char__gt=50, join_type="OR"), join_type="AND") + r = q.resolve(char_fields_context) + assert ( + r.where_criterion.get_sql(DEFAULT_SQL_CONTEXT) + == "\"char_null\"='80' AND NOT (\"char\"<'5' OR \"char\">'50')" + ) + + +def test_q_call_complex_char2(char_fields_context): + q = Q( + Q(char_null="80"), + ~Q(Q(char__lt="5"), Q(char__gt="50"), join_type="OR"), + join_type="AND", + ) + r = q.resolve(char_fields_context) + assert ( + r.where_criterion.get_sql(DEFAULT_SQL_CONTEXT) + == "\"char_null\"='80' AND NOT (\"char\"<'5' OR \"char\">'50')" + ) + + +def test_q_call_complex_char3(char_fields_context): + q = Q(~Q(char__lt=5, char__gt=50, join_type="OR"), join_type="AND", char_null=80) + r = q.resolve(char_fields_context) + assert ( + r.where_criterion.get_sql(DEFAULT_SQL_CONTEXT) + == "\"char_null\"='80' AND NOT (\"char\"<'5' OR \"char\">'50')" + ) + + +def test_q_call_with_blank_and(char_fields_context): + q = Q(Q(id__gt=5), Q(), join_type=Q.AND) + r = q.resolve(char_fields_context) + assert r.where_criterion.get_sql(DEFAULT_SQL_CONTEXT) == '"id">5' + + +def test_q_call_with_blank_or(char_fields_context): + q = Q(Q(id__gt=5), Q(), join_type=Q.OR) + r = q.resolve(char_fields_context) + assert r.where_criterion.get_sql(DEFAULT_SQL_CONTEXT) == '"id">5' + + +def test_q_call_with_blank_and2(char_fields_context): + q = Q(id__gt=5) & Q() + r = q.resolve(char_fields_context) + assert r.where_criterion.get_sql(DEFAULT_SQL_CONTEXT) == '"id">5' + + +def test_q_call_with_blank_or2(char_fields_context): + q = Q(id__gt=5) | Q() + r = q.resolve(char_fields_context) + assert r.where_criterion.get_sql(DEFAULT_SQL_CONTEXT) == '"id">5' + + +def test_q_call_with_blank_and3(char_fields_context): + q = Q() & Q(id__gt=5) + r = q.resolve(char_fields_context) + assert r.where_criterion.get_sql(DEFAULT_SQL_CONTEXT) == '"id">5' + + +def test_q_call_with_blank_or3(char_fields_context): + q = Q() | Q(id__gt=5) + 
r = q.resolve(char_fields_context) + assert r.where_criterion.get_sql(DEFAULT_SQL_CONTEXT) == '"id">5' + + +def test_q_call_annotations_resolved(db): + q = Q(id__gt=5) | Q(annotated__lt=5) + r = q.resolve( + ResolveContext( + model=IntFields, + table=IntFields._meta.basequery, + annotations={"annotated": F("intnum")}, + custom_filters={ + "annotated__lt": { + "field": "annotated", + "source_field": "annotated", + "operator": operator.lt, + } + }, ) - self.assertEqual(r.where_criterion.get_sql(DEFAULT_SQL_CONTEXT), '"id">5 OR "intnum"<5') + ) + assert r.where_criterion.get_sql(DEFAULT_SQL_CONTEXT) == '"id">5 OR "intnum"<5' diff --git a/tests/test_query_api.py b/tests/test_query_api.py index 26694fc5a..6573f170a 100644 --- a/tests/test_query_api.py +++ b/tests/test_query_api.py @@ -2,6 +2,8 @@ from typing import TypedDict, Union, cast +import pytest +import pytest_asyncio from pydantic import BaseModel, TypeAdapter, ValidationError from pypika_tortoise import Query, Table from pypika_tortoise.context import SqlContext @@ -10,10 +12,10 @@ from typing_extensions import assert_type from tests.testmodels import Tournament -from tortoise import Tortoise, fields +from tortoise import fields from tortoise.connection import connections +from tortoise.context import TortoiseContext, tortoise_test_context from tortoise.contrib import test -from tortoise.contrib.test import SimpleTestCase from tortoise.exceptions import ParamsError from tortoise.models import Model from tortoise.query_api import QueryResult, execute_pypika @@ -34,214 +36,306 @@ class QueryRowDict(TypedDict): name: str -class TestQueryApi(SimpleTestCase): - async def asyncSetUp(self) -> None: - await super().asyncSetUp() - await Tortoise.init(db_url="sqlite://:memory:", modules={"models": [__name__]}) - await Tortoise.generate_schemas() +# ============================================================================= +# Tests for TestQueryApi (formerly SimpleTestCase) +# Uses custom in-memory SQLite 
initialization +# ============================================================================= + + +@pytest_asyncio.fixture +async def query_api_db(): + """Fixture for QueryApi tests that initializes an in-memory SQLite database.""" + async with tortoise_test_context(modules=[__name__]) as ctx: await QueryModel.create(id=1, name="alpha") await QueryModel.create(id=2, name="beta") + yield ctx - async def _tearDownDB(self) -> None: - await Tortoise.get_connection("default").close() - async def test_execute_pypika(self) -> None: - table = QueryModel.get_table() - query = Query.from_(table).select(table.id, table.name).where(table.name == "alpha") +@pytest.mark.asyncio +async def test_execute_pypika(query_api_db) -> None: + table = QueryModel.get_table() + query = Query.from_(table).select(table.id, table.name).where(table.name == "alpha") - result: QueryResult[dict] = await execute_pypika(query) + result: QueryResult[dict] = await execute_pypika(query) - self.assertEqual(result.rows, [{"id": 1, "name": "alpha"}]) - self.assertEqual(result.rows_affected, 1) + assert result.rows == [{"id": 1, "name": "alpha"}] + assert result.rows_affected == 1 - async def test_execute_pypika_metadata(self) -> None: - table = QueryModel.get_table() - query = Query.from_(table).select(table.id, table.name) - result: QueryResult[dict] = await execute_pypika(query) +@pytest.mark.asyncio +async def test_execute_pypika_metadata(query_api_db) -> None: + table = QueryModel.get_table() + query = Query.from_(table).select(table.id, table.name) - self.assertEqual(result.rows_affected, 2) + result: QueryResult[dict] = await execute_pypika(query) - async def test_execute_pypika_update_rows_affected(self) -> None: - table = QueryModel.get_table() - query = Query.update(table).set(table.name, "gamma").where(table.id == 1) + assert result.rows_affected == 2 - result: QueryResult[dict] = await execute_pypika(query) - self.assertEqual(result.rows, []) - self.assertEqual(result.rows_affected, 1) 
+@pytest.mark.asyncio +async def test_execute_pypika_update_rows_affected(query_api_db) -> None: + table = QueryModel.get_table() + query = Query.update(table).set(table.name, "gamma").where(table.id == 1) - async def test_execute_pypika_insert_rows_affected(self) -> None: - table = QueryModel.get_table() - query = Query.into(table).columns(table.name).insert("delta") + result: QueryResult[dict] = await execute_pypika(query) - result: QueryResult[dict] = await execute_pypika(query) + assert result.rows == [] + assert result.rows_affected == 1 - self.assertEqual(result.rows, []) - self.assertEqual(result.rows_affected, 1) - async def test_query_parameterization(self) -> None: - table = QueryModel.get_table() - query = Query.from_(table).select(table.id).where(table.name == "alpha") - db = connections.get("default") +@pytest.mark.asyncio +async def test_execute_pypika_insert_rows_affected(query_api_db) -> None: + table = QueryModel.get_table() + query = Query.into(table).columns(table.name).insert("delta") - sql, params = query.get_parameterized_sql(db.query_class.SQL_CONTEXT) + result: QueryResult[dict] = await execute_pypika(query) - self.assertIn("alpha", params) - self.assertNotIn("alpha", sql) + assert result.rows == [] + assert result.rows_affected == 1 - async def test_execute_pypika_pydantic_schema(self) -> None: - table = QueryModel.get_table() - query = Query.from_(table).select(table.id, table.name).where(table.name == "alpha") - result = cast(QueryResult[QueryRow], await execute_pypika(query, schema=QueryRow)) +@pytest.mark.asyncio +async def test_query_parameterization(query_api_db) -> None: + table = QueryModel.get_table() + query = Query.from_(table).select(table.id).where(table.name == "alpha") + db = connections.get("default") - self.assertIsInstance(result.rows[0], QueryRow) - self.assertEqual(result.rows[0].model_dump(), {"id": 1, "name": "alpha"}) + sql, params = query.get_parameterized_sql(db.query_class.SQL_CONTEXT) - async def 
test_execute_pypika_pydantic_type_adapter(self) -> None: - table = QueryModel.get_table() - query = Query.from_(table).select(table.id, table.name).where(table.name == "alpha") - adapter = TypeAdapter(dict[str, int | str]) + assert "alpha" in params + assert "alpha" not in sql - result: QueryResult[dict[str, Union[int, str]]] = await execute_pypika( # noqa: UP007 - query, - schema=adapter, - ) - self.assertEqual(result.rows, [{"id": 1, "name": "alpha"}]) +@pytest.mark.asyncio +async def test_execute_pypika_pydantic_schema(query_api_db) -> None: + table = QueryModel.get_table() + query = Query.from_(table).select(table.id, table.name).where(table.name == "alpha") - async def test_execute_pypika_typed_dict_schema(self) -> None: - table = QueryModel.get_table() - query = Query.from_(table).select(table.id, table.name).where(table.name == "alpha") + result = cast(QueryResult[QueryRow], await execute_pypika(query, schema=QueryRow)) - result: QueryResult[QueryRowDict] = await execute_pypika(query, schema=QueryRowDict) + assert isinstance(result.rows[0], QueryRow) + assert result.rows[0].model_dump() == {"id": 1, "name": "alpha"} - assert_type(result, QueryResult[QueryRowDict]) - assert_type(result.rows, list[QueryRowDict]) - self.assertEqual(result.rows, [{"id": 1, "name": "alpha"}]) - async def test_execute_pypika_empty_result(self) -> None: - table = QueryModel.get_table() - query = Query.from_(table).select(table.id, table.name).where(table.name == "missing") +@pytest.mark.asyncio +async def test_execute_pypika_pydantic_type_adapter(query_api_db) -> None: + table = QueryModel.get_table() + query = Query.from_(table).select(table.id, table.name).where(table.name == "alpha") + adapter = TypeAdapter(dict[str, int | str]) - result: QueryResult[dict] = await execute_pypika(query) + result: QueryResult[dict[str, Union[int, str]]] = await execute_pypika( # noqa: UP007 + query, + schema=adapter, + ) - self.assertEqual(result.rows, []) - self.assertEqual(result.rows_affected, 
0) + assert result.rows == [{"id": 1, "name": "alpha"}] - async def test_execute_pypika_invalid_schema_raises(self) -> None: - table = QueryModel.get_table() - query = Query.from_(table).select(table.name.as_("id"), table.name) - with self.assertRaises(ValidationError): - await execute_pypika(query, schema=QueryRow) +@pytest.mark.asyncio +async def test_execute_pypika_typed_dict_schema(query_api_db) -> None: + table = QueryModel.get_table() + query = Query.from_(table).select(table.id, table.name).where(table.name == "alpha") + result: QueryResult[QueryRowDict] = await execute_pypika(query, schema=QueryRowDict) -class TestQueryApiRowsAffected(test.TestCase): - async def asyncSetUp(self) -> None: - await super().asyncSetUp() - alpha = await Tournament.create(name="alpha") - beta = await Tournament.create(name="beta") - self._alpha_id = alpha.id - self._beta_id = beta.id + assert_type(result, QueryResult[QueryRowDict]) + assert_type(result.rows, list[QueryRowDict]) + assert result.rows == [{"id": 1, "name": "alpha"}] - def _is_asyncpg(self) -> bool: - return "tortoise.backends.asyncpg" in type(self._db).__module__ - def _is_psycopg(self) -> bool: - return "tortoise.backends.psycopg" in type(self._db).__module__ +@pytest.mark.asyncio +async def test_execute_pypika_empty_result(query_api_db) -> None: + table = QueryModel.get_table() + query = Query.from_(table).select(table.id, table.name).where(table.name == "missing") - def _is_mysql(self) -> bool: - return "tortoise.backends.mysql" in type(self._db).__module__ + result: QueryResult[dict] = await execute_pypika(query) - def _is_odbc(self) -> bool: - return "tortoise.backends.odbc" in type(self._db).__module__ + assert result.rows == [] + assert result.rows_affected == 0 - def _select_query(self) -> QueryBuilder: - table = Tournament.get_table() - return Query.from_(table).select(table.id, table.name).orderby(table.id) - def _sql_context(self) -> SqlContext: - ctx = self._db.query_class.SQL_CONTEXT - if 
self._is_psycopg() and ctx.parameterizer is None: - ctx = ctx.copy(parameterizer=Parameterizer(placeholder_factory=lambda _: "%s")) - return ctx +@pytest.mark.asyncio +async def test_execute_pypika_invalid_schema_raises(query_api_db) -> None: + table = QueryModel.get_table() + query = Query.from_(table).select(table.name.as_("id"), table.name) - @test.requireCapability(dialect="sqlite") - async def test_rows_affected_select_sqlite(self) -> None: - result: QueryResult[dict] = await execute_pypika(self._select_query()) + with pytest.raises(ValidationError): + await execute_pypika(query, schema=QueryRow) - self.assertEqual(result.rows_affected, len(result.rows)) - self.assertEqual(len(result.rows), 2) - @test.requireCapability(dialect="postgres") - async def test_rows_affected_select_asyncpg(self) -> None: - if not self._is_asyncpg(): - self.skipTest("asyncpg only") +# ============================================================================= +# Tests for TestQueryApiRowsAffected (formerly test.TestCase) +# Uses db fixture with Tournament model +# ============================================================================= - result: QueryResult[dict] = await execute_pypika(self._select_query()) - self.assertEqual(result.rows_affected, len(result.rows)) - self.assertEqual(len(result.rows), 2) +def _is_asyncpg(db) -> bool: + return "tortoise.backends.asyncpg" in type(db.db()).__module__ - async def test_rows_affected_select_driver_rowcount(self) -> None: - if not (self._is_mysql() or self._is_odbc() or self._is_psycopg()): - self.skipTest("mysql/odbc/psycopg only") - query: QueryBuilder = self._select_query() - sql, params = query.get_parameterized_sql(self._sql_context()) - raw_rowcount, _ = await self._db.execute_query(sql, params) - result: QueryResult[dict] = await execute_pypika(query) +def _is_psycopg(db) -> bool: + return "tortoise.backends.psycopg" in type(db.db()).__module__ - expected = {raw_rowcount, len(result.rows)} - self.assertIn(result.rows_affected, 
expected) - async def test_rows_affected_update(self) -> None: - table = Tournament.get_table() - query = Query.update(table).set(table.name, "gamma").where(table.id == self._alpha_id) +def _is_mysql(db) -> bool: + return "tortoise.backends.mysql" in type(db.db()).__module__ - result: QueryResult[dict] = await execute_pypika(query) - self.assertEqual(result.rows, []) - self.assertEqual(result.rows_affected, 1) +def _is_odbc(db) -> bool: + return "tortoise.backends.odbc" in type(db.db()).__module__ - async def test_rows_affected_delete(self) -> None: - table = Tournament.get_table() - query = Query.from_(table).delete().where(table.id == self._beta_id) - result: QueryResult[dict] = await execute_pypika(query) +def _select_query() -> QueryBuilder: + table = Tournament.get_table() + return Query.from_(table).select(table.id, table.name).orderby(table.id) - self.assertEqual(result.rows, []) - self.assertEqual(result.rows_affected, 1) +def _sql_context(db) -> SqlContext: + ctx = db.db().query_class.SQL_CONTEXT + if _is_psycopg(db) and ctx.parameterizer is None: + ctx = ctx.copy(parameterizer=Parameterizer(placeholder_factory=lambda _: "%s")) + return ctx + + +@pytest_asyncio.fixture +async def rows_affected_setup(db): + """Fixture to set up Tournament data for rows_affected tests.""" + alpha = await Tournament.create(name="alpha") + beta = await Tournament.create(name="beta") + return {"alpha_id": alpha.id, "beta_id": beta.id, "db": db} + + +@test.requireCapability(dialect="sqlite") +@pytest.mark.asyncio +async def test_rows_affected_select_sqlite(rows_affected_setup) -> None: + result: QueryResult[dict] = await execute_pypika(_select_query()) + + assert result.rows_affected == len(result.rows) + assert len(result.rows) == 2 + + +@test.requireCapability(dialect="postgres") +@pytest.mark.asyncio +async def test_rows_affected_select_asyncpg(rows_affected_setup) -> None: + db = rows_affected_setup["db"] + if not _is_asyncpg(db): + pytest.skip("asyncpg only") + + result: 
QueryResult[dict] = await execute_pypika(_select_query()) + + assert result.rows_affected == len(result.rows) + assert len(result.rows) == 2 + + +@pytest.mark.asyncio +async def test_rows_affected_select_driver_rowcount(rows_affected_setup) -> None: + db = rows_affected_setup["db"] + if not (_is_mysql(db) or _is_odbc(db) or _is_psycopg(db)): + pytest.skip("mysql/odbc/psycopg only") + + query: QueryBuilder = _select_query() + sql, params = query.get_parameterized_sql(_sql_context(db)) + raw_rowcount, _ = await db.db().execute_query(sql, params) + result: QueryResult[dict] = await execute_pypika(query) + + expected = {raw_rowcount, len(result.rows)} + assert result.rows_affected in expected -class TestQueryApiConnectionSelection(SimpleTestCase): - async def test_execute_pypika_explicit_connection_with_multiple_configured(self) -> None: - connections._db_config = {"first": {}, "second": {}} - query = Query.from_(Table("dummy")).select("*") - class DummyClient: - query_class = type("QueryClass", (), {"SQL_CONTEXT": None}) +@pytest.mark.asyncio +async def test_rows_affected_update(rows_affected_setup) -> None: + alpha_id = rows_affected_setup["alpha_id"] + table = Tournament.get_table() + query = Query.update(table).set(table.name, "gamma").where(table.id == alpha_id) - async def execute_query_dict_with_affected(self, query, values=None): - return [], 0 + result: QueryResult[dict] = await execute_pypika(query) - token = connections.set("second", DummyClient()) # type: ignore[arg-type] + assert result.rows == [] + assert result.rows_affected == 1 + + +@pytest.mark.asyncio +async def test_rows_affected_delete(rows_affected_setup) -> None: + beta_id = rows_affected_setup["beta_id"] + table = Tournament.get_table() + query = Query.from_(table).delete().where(table.id == beta_id) + + result: QueryResult[dict] = await execute_pypika(query) + + assert result.rows == [] + assert result.rows_affected == 1 + + +# 
============================================================================= +# Tests for TestQueryApiConnectionSelection (formerly SimpleTestCase) +# Tests connection selection behavior +# ============================================================================= + + +@pytest.mark.asyncio +async def test_execute_pypika_explicit_connection_with_multiple_configured() -> None: + """Test execute_pypika with explicit connection when multiple are configured.""" + + class DummyClient: + query_class = type("QueryClass", (), {"SQL_CONTEXT": None}) + + async def execute_query_dict_with_affected(self, query, values=None): + return [], 0 + + async with TortoiseContext() as ctx: + await ctx.init( + config={ + "connections": { + "first": "sqlite://:memory:", + "second": "sqlite://:memory:", + }, + "apps": { + "models": {"models": [__name__], "default_connection": "first"}, + }, + } + ) + await ctx.generate_schemas() + + query = Query.from_(Table("dummy")).select("*") + + token = ctx.connections.set("second", DummyClient()) # type: ignore[arg-type] try: result: QueryResult[dict] = await execute_pypika( - query, using_db=connections.get("second") + query, using_db=ctx.connections.get("second") ) finally: - connections.reset(token) + ctx.connections.reset(token) + + assert result.rows_affected == 0 + + +@pytest_asyncio.fixture +async def multi_db(): + """Fixture that sets up multiple databases for testing.""" + from tortoise.context import TortoiseContext + + ctx = TortoiseContext() + async with ctx: + await ctx.init( + config={ + "connections": { + "first": "sqlite://:memory:", + "second": "sqlite://:memory:", + }, + "apps": { + "models": {"models": [__name__], "default_connection": "first"}, + }, + } + ) + await ctx.generate_schemas() + yield ctx - self.assertEqual(result.rows_affected, 0) - async def test_execute_pypika_requires_connection_with_multiple_configured(self) -> None: - connections._db_config = {"first": {}, "second": {}} - query = 
Query.from_(Table("dummy")).select("*") +@pytest.mark.asyncio +async def test_execute_pypika_requires_connection_with_multiple_configured(multi_db) -> None: + query = Query.from_(Table("dummy")).select("*") - with self.assertRaises(ParamsError) as ctx: - await execute_pypika(query) + with pytest.raises(ParamsError) as exc_info: + await execute_pypika(query) - self.assertIn("multiple databases", str(ctx.exception)) + assert "multiple databases" in str(exc_info.value) diff --git a/tests/test_queryset.py b/tests/test_queryset.py index db3683df9..513979035 100644 --- a/tests/test_queryset.py +++ b/tests/test_queryset.py @@ -1,3 +1,6 @@ +import pytest +import pytest_asyncio + from tests.testmodels import ( Author, Book, @@ -11,7 +14,7 @@ ) from tortoise import connections from tortoise.backends.psycopg.client import PsycopgClient -from tortoise.contrib import test +from tortoise.contrib.test import requireCapability from tortoise.contrib.test.condition import NotEQ from tortoise.exceptions import ( DoesNotExist, @@ -28,862 +31,873 @@ # TODO: .filter(intnum_null=None) does not work as expected -class TestQueryset(test.TestCase): - async def asyncSetUp(self): - await super().asyncSetUp() - # Build large dataset - self.intfields = [await IntFields.create(intnum=val) for val in range(10, 100, 3)] - self.db = connections.get("models") - - async def test_all_count(self): - self.assertEqual(await IntFields.all().count(), 30) - self.assertEqual(await IntFields.filter(intnum_null=80).count(), 0) - - async def test_exists(self): - ret = await IntFields.filter(intnum=0).exists() - self.assertFalse(ret) - - ret = await IntFields.filter(intnum=10).exists() - self.assertTrue(ret) - - ret = await IntFields.filter(intnum__gt=10).exists() - self.assertTrue(ret) - - ret = await IntFields.filter(intnum__lt=10).exists() - self.assertFalse(ret) - - async def test_limit_count(self): - self.assertEqual(await IntFields.all().limit(10).count(), 10) - - async def test_limit_negative(self): - 
with self.assertRaisesRegex(ParamsError, "Limit should be non-negative number"): - await IntFields.all().limit(-10) - - @test.requireCapability(dialect="sqlite") - async def test_limit_zero(self): - sql = IntFields.all().only("id").limit(0).sql() - self.assertEqual( - sql, - 'SELECT "id" "id" FROM "intfields" LIMIT ?', - ) - - async def test_offset_count(self): - self.assertEqual(await IntFields.all().offset(10).count(), 20) - - async def test_offset_negative(self): - with self.assertRaisesRegex(ParamsError, "Offset should be non-negative number"): - await IntFields.all().offset(-10) - - async def test_slicing_start_and_stop(self) -> None: - sliced_queryset = IntFields.all().order_by("intnum")[1:5] - manually_sliced_queryset = IntFields.all().order_by("intnum").offset(1).limit(4) - self.assertSequenceEqual(await sliced_queryset, await manually_sliced_queryset) - - async def test_slicing_only_limit(self) -> None: - sliced_queryset = IntFields.all().order_by("intnum")[:5] - manually_sliced_queryset = IntFields.all().order_by("intnum").limit(5) - self.assertSequenceEqual(await sliced_queryset, await manually_sliced_queryset) - - async def test_slicing_only_offset(self) -> None: - sliced_queryset = IntFields.all().order_by("intnum")[5:] - manually_sliced_queryset = IntFields.all().order_by("intnum").offset(5) - self.assertSequenceEqual(await sliced_queryset, await manually_sliced_queryset) - - async def test_slicing_count(self) -> None: - queryset = IntFields.all().order_by("intnum")[1:5] - self.assertEqual(await queryset.count(), 4) - - def test_slicing_negative_values(self) -> None: - with self.assertRaisesRegex( - expected_exception=ParamsError, - expected_regex="Slice start should be non-negative number or None.", - ): - _ = IntFields.all()[-1:] - - with self.assertRaisesRegex( - expected_exception=ParamsError, - expected_regex="Slice stop should be non-negative number greater that slice start, " - "or None.", - ): - _ = IntFields.all()[:-1] - - def 
test_slicing_stop_before_start(self) -> None: - with self.assertRaisesRegex( - expected_exception=ParamsError, - expected_regex="Slice stop should be non-negative number greater that slice start, " - "or None.", - ): - _ = IntFields.all()[2:1] - - async def test_slicing_steps(self) -> None: - sliced_queryset = IntFields.all().order_by("intnum")[::1] - manually_sliced_queryset = IntFields.all().order_by("intnum") - self.assertSequenceEqual(await sliced_queryset, await manually_sliced_queryset) - - with self.assertRaisesRegex( - expected_exception=ParamsError, - expected_regex="Slice steps should be 1 or None.", - ): - _ = IntFields.all()[::2] - - async def test_join_count(self): - tour = await Tournament.create(name="moo") - await MinRelation.create(tournament=tour) - - self.assertEqual(await MinRelation.all().count(), 1) - self.assertEqual(await MinRelation.filter(tournament__id=tour.id).count(), 1) - - async def test_modify_dataset(self): - # Modify dataset - rows_affected = await IntFields.filter(intnum__gte=70).update(intnum_null=80) - self.assertEqual(rows_affected, 10) - self.assertEqual(await IntFields.filter(intnum_null=80).count(), 10) - self.assertEqual(await IntFields.filter(intnum_null__isnull=True).count(), 20) - await IntFields.filter(intnum_null__isnull=True).update(intnum_null=-1) - self.assertEqual(await IntFields.filter(intnum_null=None).count(), 0) - self.assertEqual(await IntFields.filter(intnum_null=-1).count(), 20) - - async def test_distinct(self): - # Test distinct - await IntFields.filter(intnum__gte=70).update(intnum_null=80) - await IntFields.filter(intnum_null__isnull=True).update(intnum_null=-1) - - self.assertEqual( - await IntFields.all() - .order_by("intnum_null") - .distinct() - .values_list("intnum_null", flat=True), - [-1, 80], - ) - - self.assertEqual( - await IntFields.all().order_by("intnum_null").distinct().values("intnum_null"), - [{"intnum_null": -1}, {"intnum_null": 80}], - ) - - async def 
test_limit_offset_values_list(self): - # Test limit/offset/ordering values_list - self.assertEqual( - await IntFields.all().order_by("intnum").limit(10).values_list("intnum", flat=True), - [10, 13, 16, 19, 22, 25, 28, 31, 34, 37], - ) - - self.assertEqual( - await IntFields.all() - .order_by("intnum") - .limit(10) - .offset(10) - .values_list("intnum", flat=True), - [40, 43, 46, 49, 52, 55, 58, 61, 64, 67], - ) - - self.assertEqual( - await IntFields.all() - .order_by("intnum") - .limit(10) - .offset(20) - .values_list("intnum", flat=True), - [70, 73, 76, 79, 82, 85, 88, 91, 94, 97], - ) - - self.assertEqual( - await IntFields.all() - .order_by("intnum") - .limit(10) - .offset(30) - .values_list("intnum", flat=True), - [], - ) - - self.assertEqual( - await IntFields.all().order_by("-intnum").limit(10).values_list("intnum", flat=True), - [97, 94, 91, 88, 85, 82, 79, 76, 73, 70], - ) - - self.assertEqual( - await IntFields.all() - .order_by("intnum") - .limit(10) - .filter(intnum__gte=40) - .values_list("intnum", flat=True), - [40, 43, 46, 49, 52, 55, 58, 61, 64, 67], - ) - - async def test_limit_offset_values(self): - # Test limit/offset/ordering values - self.assertEqual( - await IntFields.all().order_by("intnum").limit(5).values("intnum"), - [{"intnum": 10}, {"intnum": 13}, {"intnum": 16}, {"intnum": 19}, {"intnum": 22}], - ) - - self.assertEqual( - await IntFields.all().order_by("intnum").limit(5).offset(10).values("intnum"), - [{"intnum": 40}, {"intnum": 43}, {"intnum": 46}, {"intnum": 49}, {"intnum": 52}], - ) - - self.assertEqual( - await IntFields.all().order_by("intnum").limit(5).offset(30).values("intnum"), [] - ) - - self.assertEqual( - await IntFields.all().order_by("-intnum").limit(5).values("intnum"), - [{"intnum": 97}, {"intnum": 94}, {"intnum": 91}, {"intnum": 88}, {"intnum": 85}], - ) - - self.assertEqual( - await IntFields.all() - .order_by("intnum") - .limit(5) - .filter(intnum__gte=40) - .values("intnum"), - [{"intnum": 40}, {"intnum": 43}, 
{"intnum": 46}, {"intnum": 49}, {"intnum": 52}], - ) - - async def test_in_bulk(self): - id_list = [item.pk for item in await IntFields.all().only("id").limit(2)] - ret = await IntFields.in_bulk(id_list=id_list) - self.assertEqual(list(ret.keys()), id_list) - - async def test_first(self): - # Test first - self.assertEqual( - (await IntFields.all().order_by("intnum").filter(intnum__gte=40).first()).intnum, 40 - ) - self.assertEqual( - (await IntFields.all().order_by("intnum").filter(intnum__gte=40).first().values())[ - "intnum" - ], - 40, - ) - self.assertEqual( - (await IntFields.all().order_by("intnum").filter(intnum__gte=40).first().values_list())[ - 1 - ], - 40, - ) - - self.assertEqual( - await IntFields.all().order_by("intnum").filter(intnum__gte=400).first(), None - ) - self.assertEqual( - await IntFields.all().order_by("intnum").filter(intnum__gte=400).first().values(), None - ) - self.assertEqual( - await IntFields.all().order_by("intnum").filter(intnum__gte=400).first().values_list(), - None, - ) - - async def test_last(self): - self.assertEqual( - (await IntFields.all().order_by("intnum").filter(intnum__gte=40).last()).intnum, 97 - ) - self.assertEqual( - (await IntFields.all().order_by("intnum").filter(intnum__gte=40).last().values())[ - "intnum" - ], - 97, - ) - self.assertEqual( - (await IntFields.all().order_by("intnum").filter(intnum__gte=40).last().values_list())[ - 1 - ], - 97, - ) - - self.assertEqual( - await IntFields.all().order_by("intnum").filter(intnum__gte=400).last(), None - ) - self.assertEqual( - await IntFields.all().order_by("intnum").filter(intnum__gte=400).last().values(), None - ) - self.assertEqual( - await IntFields.all().order_by("intnum").filter(intnum__gte=400).last().values_list(), - None, - ) - self.assertEqual((await IntFields.all().filter(intnum__gte=40).last()).intnum, 97) - - async def test_latest(self): - self.assertEqual((await IntFields.all().latest("intnum")).intnum, 97) - self.assertEqual( - (await 
IntFields.all().order_by("-intnum").first()).intnum, - (await IntFields.all().latest("intnum")).intnum, - ) - self.assertEqual((await IntFields.all().filter(intnum__gte=40).latest("intnum")).intnum, 97) - self.assertEqual( - (await IntFields.all().filter(intnum__gte=40).latest("intnum").values())["intnum"], - 97, - ) - self.assertEqual( - (await IntFields.all().filter(intnum__gte=40).latest("intnum").values_list())[1], - 97, - ) - - self.assertEqual(await IntFields.all().filter(intnum__gte=400).latest("intnum"), None) - self.assertEqual( - await IntFields.all().filter(intnum__gte=400).latest("intnum").values(), None - ) - self.assertEqual( - await IntFields.all().filter(intnum__gte=400).latest("intnum").values_list(), - None, - ) - - with self.assertRaises(FieldError): - await IntFields.all().latest() - - with self.assertRaises(FieldError): - await IntFields.all().latest("some_unkown_field") - - async def test_earliest(self): - self.assertEqual((await IntFields.all().earliest("intnum")).intnum, 10) - self.assertEqual( - (await IntFields.all().order_by("intnum").first()).intnum, - (await IntFields.all().earliest("intnum")).intnum, - ) - self.assertEqual( - (await IntFields.all().filter(intnum__gte=40).earliest("intnum")).intnum, 40 - ) - self.assertEqual( - (await IntFields.all().filter(intnum__gte=40).earliest("intnum").values())["intnum"], - 40, - ) - self.assertEqual( - (await IntFields.all().filter(intnum__gte=40).earliest("intnum").values_list())[1], - 40, - ) - - self.assertEqual(await IntFields.all().filter(intnum__gte=400).earliest("intnum"), None) - self.assertEqual( - await IntFields.all().filter(intnum__gte=400).earliest("intnum").values(), None - ) - self.assertEqual( - await IntFields.all().filter(intnum__gte=400).earliest("intnum").values_list(), - None, - ) - - with self.assertRaises(FieldError): - await IntFields.all().earliest() - - with self.assertRaises(FieldError): - await IntFields.all().earliest("some_unkown_field") - - async def 
test_get_or_none(self): - self.assertEqual((await IntFields.all().get_or_none(intnum=40)).intnum, 40) - self.assertEqual((await IntFields.all().get_or_none(intnum=40).values())["intnum"], 40) - self.assertEqual((await IntFields.all().get_or_none(intnum=40).values_list())[1], 40) - - self.assertEqual( - await IntFields.all().order_by("intnum").get_or_none(intnum__gte=400), None - ) - - self.assertEqual( - await IntFields.all().order_by("intnum").get_or_none(intnum__gte=400).values(), None - ) - - self.assertEqual( - await IntFields.all().order_by("intnum").get_or_none(intnum__gte=400).values_list(), - None, - ) - - with self.assertRaises(MultipleObjectsReturned): - await IntFields.all().order_by("intnum").get_or_none(intnum__gte=40) - - with self.assertRaises(MultipleObjectsReturned): - await IntFields.all().order_by("intnum").get_or_none(intnum__gte=40).values() - - with self.assertRaises(MultipleObjectsReturned): - await IntFields.all().order_by("intnum").get_or_none(intnum__gte=40).values_list() - - async def test_get(self): - await IntFields.filter(intnum__gte=70).update(intnum_null=80) - - # Test get - self.assertEqual((await IntFields.all().get(intnum=40)).intnum, 40) - self.assertEqual((await IntFields.all().get(intnum=40).values())["intnum"], 40) - self.assertEqual((await IntFields.all().get(intnum=40).values_list())[1], 40) - - self.assertEqual((await IntFields.all().all().all().all().all().get(intnum=40)).intnum, 40) - self.assertEqual( - (await IntFields.all().all().all().all().all().get(intnum=40).values())["intnum"], 40 - ) - self.assertEqual( - (await IntFields.all().all().all().all().all().get(intnum=40).values_list())[1], 40 - ) - - self.assertEqual((await IntFields.get(intnum=40)).intnum, 40) - self.assertEqual((await IntFields.get(intnum=40).values())["intnum"], 40) - self.assertEqual((await IntFields.get(intnum=40).values_list())[1], 40) - - with self.assertRaises(DoesNotExist): - await IntFields.all().get(intnum=41) - - with 
self.assertRaises(DoesNotExist): - await IntFields.all().get(intnum=41).values() - - with self.assertRaises(DoesNotExist): - await IntFields.all().get(intnum=41).values_list() - - with self.assertRaises(DoesNotExist): - await IntFields.get(intnum=41) - - with self.assertRaises(DoesNotExist): - await IntFields.get(intnum=41).values() - - with self.assertRaises(DoesNotExist): - await IntFields.get(intnum=41).values_list() - - with self.assertRaises(MultipleObjectsReturned): - await IntFields.all().get(intnum_null=80) - - with self.assertRaises(MultipleObjectsReturned): - await IntFields.all().get(intnum_null=80).values() - - with self.assertRaises(MultipleObjectsReturned): - await IntFields.all().get(intnum_null=80).values_list() - - with self.assertRaises(MultipleObjectsReturned): - await IntFields.get(intnum_null=80) - - with self.assertRaises(MultipleObjectsReturned): - await IntFields.get(intnum_null=80).values() - - with self.assertRaises(MultipleObjectsReturned): - await IntFields.get(intnum_null=80).values_list() - - async def test_delete(self): - # Test delete - await (await IntFields.get(intnum=40)).delete() - - with self.assertRaises(DoesNotExist): - await IntFields.get(intnum=40) +@pytest_asyncio.fixture +async def intfields_data(db): + """Build large dataset for IntFields tests.""" + intfields = [await IntFields.create(intnum=val) for val in range(10, 100, 3)] + return intfields - self.assertEqual(await IntFields.all().count(), 29) - rows_affected = ( - await IntFields.all().order_by("intnum").limit(10).filter(intnum__gte=70).delete() - ) - self.assertEqual(rows_affected, 10) +@pytest.mark.asyncio +async def test_all_count(db, intfields_data): + assert await IntFields.all().count() == 30 + assert await IntFields.filter(intnum_null=80).count() == 0 - self.assertEqual(await IntFields.all().count(), 19) - @test.requireCapability(support_update_limit_order_by=True) - async def test_delete_limit(self): - await IntFields.all().limit(1).delete() - 
self.assertEqual(await IntFields.all().count(), 29) - - @test.requireCapability(support_update_limit_order_by=True) - async def test_delete_limit_order_by(self): - await IntFields.all().limit(1).order_by("-id").delete() - self.assertEqual(await IntFields.all().count(), 29) - with self.assertRaises(DoesNotExist): - await IntFields.get(intnum=97) - - async def test_async_iter(self): - counter = 0 - async for _ in IntFields.all(): - counter += 1 - - self.assertEqual(await IntFields.all().count(), counter) - - async def test_update_basic(self): - obj0 = await IntFields.create(intnum=2147483647) - await IntFields.filter(id=obj0.id).update(intnum=2147483646) - obj = await IntFields.get(id=obj0.id) - self.assertEqual(obj.intnum, 2147483646) - self.assertEqual(obj.intnum_null, None) - - async def test_update_f_expression(self): - obj0 = await IntFields.create(intnum=2147483647) - await IntFields.filter(id=obj0.id).update(intnum=F("intnum") - 1) - obj = await IntFields.get(id=obj0.id) - self.assertEqual(obj.intnum, 2147483646) - - async def test_update_badparam(self): - obj0 = await IntFields.create(intnum=2147483647) - with self.assertRaisesRegex(FieldError, "Unknown keyword argument"): - await IntFields.filter(id=obj0.id).update(badparam=1) - - async def test_update_pk(self): - obj0 = await IntFields.create(intnum=2147483647) - with self.assertRaisesRegex(IntegrityError, "is PK and can not be updated"): - await IntFields.filter(id=obj0.id).update(id=1) - - async def test_update_virtual(self): - tour = await Tournament.create(name="moo") - obj0 = await MinRelation.create(tournament=tour) - with self.assertRaisesRegex(FieldError, "is virtual and can not be updated"): - await MinRelation.filter(id=obj0.id).update(participants=[]) - - async def test_bad_ordering(self): - with self.assertRaisesRegex(FieldError, "Unknown field moo1fip for model IntFields"): - await IntFields.all().order_by("moo1fip") - - async def test_duplicate_values(self): - with 
self.assertRaisesRegex(FieldError, "Duplicate key intnum"): - await IntFields.all().values("intnum", "intnum") - - async def test_duplicate_values_list(self): - await IntFields.all().values_list("intnum", "intnum") - - async def test_duplicate_values_kw(self): - with self.assertRaisesRegex(FieldError, "Duplicate key intnum"): - await IntFields.all().values("intnum", intnum="intnum_null") - - async def test_duplicate_values_kw_badmap(self): - with self.assertRaisesRegex(FieldError, 'Unknown field "intnum2" for model "IntFields"'): - await IntFields.all().values(intnum="intnum2") - - async def test_bad_values(self): - with self.assertRaisesRegex(FieldError, 'Unknown field "int2num" for model "IntFields"'): - await IntFields.all().values("int2num") - - async def test_bad_values_list(self): - with self.assertRaisesRegex(FieldError, 'Unknown field "int2num" for model "IntFields"'): - await IntFields.all().values_list("int2num") - - async def test_many_flat_values_list(self): - with self.assertRaisesRegex( - TypeError, "You can flat value_list only if contains one field" - ): - await IntFields.all().values_list("intnum", "intnum_null", flat=True) - - async def test_all_flat_values_list(self): - with self.assertRaisesRegex( - TypeError, "You can flat value_list only if contains one field" - ): - await IntFields.all().values_list(flat=True) - - async def test_all_values_list(self): - data = await IntFields.all().order_by("id").values_list() - self.assertEqual(data[2], (self.intfields[2].id, 16, None)) - - async def test_all_values(self): - data = await IntFields.all().order_by("id").values() - self.assertEqual(data[2], {"id": self.intfields[2].id, "intnum": 16, "intnum_null": None}) - - async def test_order_by_bad_value(self): - with self.assertRaisesRegex(FieldError, "Unknown field badid for model IntFields"): - await IntFields.all().order_by("badid").values_list() - - async def test_annotate_order_expression(self): - data = ( - await IntFields.annotate(idp=F("id") + 1) - 
.order_by("-idp") - .first() - .values_list("id", "idp") - ) - self.assertEqual(data[0] + 1, data[1]) - - async def test_annotate_order_rawsql(self): - qs = IntFields.annotate(idp=RawSQL("id+1")).order_by("-idp") - data = await qs.first().values_list("id", "idp") - self.assertEqual(data[0] + 1, data[1]) - - async def test_annotate_expression_filter(self): - count = await IntFields.annotate(intnum1=F("intnum") + 1).filter(intnum1__gt=30).count() - self.assertEqual(count, 23) - - async def test_get_raw_sql(self): - sql = IntFields.all().sql() - self.assertRegex(sql, r"^SELECT.+FROM.+") - - @test.requireCapability(support_index_hint=True) - async def test_force_index(self): - sql = IntFields.filter(pk=1).only("id").force_index("index_name").sql() - self.assertEqual( - sql, - "SELECT `id` `id` FROM `intfields` FORCE INDEX (`index_name`) WHERE `id`=%s", - ) - - sql_again = IntFields.filter(pk=1).only("id").force_index("index_name").sql() - self.assertEqual( - sql_again, - "SELECT `id` `id` FROM `intfields` FORCE INDEX (`index_name`) WHERE `id`=%s", - ) - - @test.requireCapability(support_index_hint=True) - async def test_force_index_available_in_more_query(self): - sql_ValuesQuery = IntFields.filter(pk=1).force_index("index_name").values("id").sql() - self.assertEqual( - sql_ValuesQuery, - "SELECT `id` `id` FROM `intfields` FORCE INDEX (`index_name`) WHERE `id`=%s", - ) - - sql_ValuesListQuery = ( - IntFields.filter(pk=1).force_index("index_name").values_list("id").sql() - ) - self.assertEqual( - sql_ValuesListQuery, - "SELECT `id` `0` FROM `intfields` FORCE INDEX (`index_name`) WHERE `id`=%s", - ) - - sql_CountQuery = IntFields.filter(pk=1).force_index("index_name").count().sql() - self.assertEqual( - sql_CountQuery, - "SELECT COUNT(*) FROM `intfields` FORCE INDEX (`index_name`) WHERE `id`=%s", - ) - - sql_ExistsQuery = IntFields.filter(pk=1).force_index("index_name").exists().sql() - self.assertEqual( - sql_ExistsQuery, - "SELECT 1 FROM `intfields` FORCE INDEX 
(`index_name`) WHERE `id`=%s LIMIT %s", - ) - - @test.requireCapability(support_index_hint=True) - async def test_use_index(self): - sql = IntFields.filter(pk=1).only("id").use_index("index_name").sql() - self.assertEqual( - sql, - "SELECT `id` `id` FROM `intfields` USE INDEX (`index_name`) WHERE `id`=%s", - ) - - sql_again = IntFields.filter(pk=1).only("id").use_index("index_name").sql() - self.assertEqual( - sql_again, - "SELECT `id` `id` FROM `intfields` USE INDEX (`index_name`) WHERE `id`=%s", - ) - - @test.requireCapability(support_index_hint=True) - async def test_use_index_available_in_more_query(self): - sql_ValuesQuery = IntFields.filter(pk=1).use_index("index_name").values("id").sql() - self.assertEqual( - sql_ValuesQuery, - "SELECT `id` `id` FROM `intfields` USE INDEX (`index_name`) WHERE `id`=%s", - ) - - sql_ValuesListQuery = IntFields.filter(pk=1).use_index("index_name").values_list("id").sql() - self.assertEqual( - sql_ValuesListQuery, - "SELECT `id` `0` FROM `intfields` USE INDEX (`index_name`) WHERE `id`=%s", - ) - - sql_CountQuery = IntFields.filter(pk=1).use_index("index_name").count().sql() - self.assertEqual( - sql_CountQuery, - "SELECT COUNT(*) FROM `intfields` USE INDEX (`index_name`) WHERE `id`=%s", - ) - - sql_ExistsQuery = IntFields.filter(pk=1).use_index("index_name").exists().sql() - self.assertEqual( - sql_ExistsQuery, - "SELECT 1 FROM `intfields` USE INDEX (`index_name`) WHERE `id`=%s LIMIT %s", - ) - - @test.requireCapability(support_for_update=True) - async def test_select_for_update(self): - sql1 = IntFields.filter(pk=1).only("id").select_for_update().sql() - sql2 = IntFields.filter(pk=1).only("id").select_for_update(nowait=True).sql() - sql3 = IntFields.filter(pk=1).only("id").select_for_update(skip_locked=True).sql() - sql4 = IntFields.filter(pk=1).only("id").select_for_update(of=("intfields",)).sql() - sql5 = IntFields.filter(pk=1).only("id").select_for_update(no_key=True).sql() - - dialect = self.db.schema_generator.DIALECT - if 
dialect == "postgres": - if isinstance(self.db, PsycopgClient): - self.assertEqual( - sql1, - 'SELECT "id" "id" FROM "intfields" WHERE "id"=%s FOR UPDATE', - ) - self.assertEqual( - sql2, - 'SELECT "id" "id" FROM "intfields" WHERE "id"=%s FOR UPDATE NOWAIT', - ) - self.assertEqual( - sql3, - 'SELECT "id" "id" FROM "intfields" WHERE "id"=%s FOR UPDATE SKIP LOCKED', - ) - self.assertEqual( - sql4, - 'SELECT "id" "id" FROM "intfields" WHERE "id"=%s FOR UPDATE OF "intfields"', - ) - self.assertEqual( - sql5, - 'SELECT "id" "id" FROM "intfields" WHERE "id"=%s FOR NO KEY UPDATE', - ) - else: - self.assertEqual( - sql1, - 'SELECT "id" "id" FROM "intfields" WHERE "id"=$1 FOR UPDATE', - ) - self.assertEqual( - sql2, - 'SELECT "id" "id" FROM "intfields" WHERE "id"=$1 FOR UPDATE NOWAIT', - ) - self.assertEqual( - sql3, - 'SELECT "id" "id" FROM "intfields" WHERE "id"=$1 FOR UPDATE SKIP LOCKED', - ) - self.assertEqual( - sql4, - 'SELECT "id" "id" FROM "intfields" WHERE "id"=$1 FOR UPDATE OF "intfields"', - ) - self.assertEqual( - sql5, - 'SELECT "id" "id" FROM "intfields" WHERE "id"=$1 FOR NO KEY UPDATE', - ) - elif dialect == "mysql": - self.assertEqual( - sql1, - "SELECT `id` `id` FROM `intfields` WHERE `id`=%s FOR UPDATE", - ) - self.assertEqual( - sql2, - "SELECT `id` `id` FROM `intfields` WHERE `id`=%s FOR UPDATE NOWAIT", - ) - self.assertEqual( - sql3, - "SELECT `id` `id` FROM `intfields` WHERE `id`=%s FOR UPDATE SKIP LOCKED", - ) - self.assertEqual( - sql4, - "SELECT `id` `id` FROM `intfields` WHERE `id`=%s FOR UPDATE OF `intfields`", +@pytest.mark.asyncio +async def test_exists(db, intfields_data): + ret = await IntFields.filter(intnum=0).exists() + assert not ret + + ret = await IntFields.filter(intnum=10).exists() + assert ret + + ret = await IntFields.filter(intnum__gt=10).exists() + assert ret + + ret = await IntFields.filter(intnum__lt=10).exists() + assert not ret + + +@pytest.mark.asyncio +async def test_limit_count(db, intfields_data): + assert await 
IntFields.all().limit(10).count() == 10 + + +@pytest.mark.asyncio +async def test_limit_negative(db, intfields_data): + with pytest.raises(ParamsError, match="Limit should be non-negative number"): + await IntFields.all().limit(-10) + + +@requireCapability(dialect="sqlite") +@pytest.mark.asyncio +async def test_limit_zero(db, intfields_data): + sql = IntFields.all().only("id").limit(0).sql() + assert sql == 'SELECT "id" "id" FROM "intfields" LIMIT ?' + + +@pytest.mark.asyncio +async def test_offset_count(db, intfields_data): + assert await IntFields.all().offset(10).count() == 20 + + +@pytest.mark.asyncio +async def test_offset_negative(db, intfields_data): + with pytest.raises(ParamsError, match="Offset should be non-negative number"): + await IntFields.all().offset(-10) + + +@pytest.mark.asyncio +async def test_slicing_start_and_stop(db, intfields_data): + sliced_queryset = IntFields.all().order_by("intnum")[1:5] + manually_sliced_queryset = IntFields.all().order_by("intnum").offset(1).limit(4) + assert list(await sliced_queryset) == list(await manually_sliced_queryset) + + +@pytest.mark.asyncio +async def test_slicing_only_limit(db, intfields_data): + sliced_queryset = IntFields.all().order_by("intnum")[:5] + manually_sliced_queryset = IntFields.all().order_by("intnum").limit(5) + assert list(await sliced_queryset) == list(await manually_sliced_queryset) + + +@pytest.mark.asyncio +async def test_slicing_only_offset(db, intfields_data): + sliced_queryset = IntFields.all().order_by("intnum")[5:] + manually_sliced_queryset = IntFields.all().order_by("intnum").offset(5) + assert list(await sliced_queryset) == list(await manually_sliced_queryset) + + +@pytest.mark.asyncio +async def test_slicing_count(db, intfields_data): + queryset = IntFields.all().order_by("intnum")[1:5] + assert await queryset.count() == 4 + + +def test_slicing_negative_values(db): + with pytest.raises( + ParamsError, + match="Slice start should be non-negative number or None.", + ): + _ = 
IntFields.all()[-1:] + + with pytest.raises( + ParamsError, + match="Slice stop should be non-negative number greater that slice start, or None.", + ): + _ = IntFields.all()[:-1] + + +def test_slicing_stop_before_start(db): + with pytest.raises( + ParamsError, + match="Slice stop should be non-negative number greater that slice start, or None.", + ): + _ = IntFields.all()[2:1] + + +@pytest.mark.asyncio +async def test_slicing_steps(db, intfields_data): + sliced_queryset = IntFields.all().order_by("intnum")[::1] + manually_sliced_queryset = IntFields.all().order_by("intnum") + assert list(await sliced_queryset) == list(await manually_sliced_queryset) + + with pytest.raises( + ParamsError, + match="Slice steps should be 1 or None.", + ): + _ = IntFields.all()[::2] + + +@pytest.mark.asyncio +async def test_join_count(db): + tour = await Tournament.create(name="moo") + await MinRelation.create(tournament=tour) + + assert await MinRelation.all().count() == 1 + assert await MinRelation.filter(tournament__id=tour.id).count() == 1 + + +@pytest.mark.asyncio +async def test_modify_dataset(db, intfields_data): + # Modify dataset + rows_affected = await IntFields.filter(intnum__gte=70).update(intnum_null=80) + assert rows_affected == 10 + assert await IntFields.filter(intnum_null=80).count() == 10 + assert await IntFields.filter(intnum_null__isnull=True).count() == 20 + await IntFields.filter(intnum_null__isnull=True).update(intnum_null=-1) + assert await IntFields.filter(intnum_null=None).count() == 0 + assert await IntFields.filter(intnum_null=-1).count() == 20 + + +@pytest.mark.asyncio +async def test_distinct(db, intfields_data): + # Test distinct + await IntFields.filter(intnum__gte=70).update(intnum_null=80) + await IntFields.filter(intnum_null__isnull=True).update(intnum_null=-1) + + assert await IntFields.all().order_by("intnum_null").distinct().values_list( + "intnum_null", flat=True + ) == [-1, 80] + + assert await 
IntFields.all().order_by("intnum_null").distinct().values("intnum_null") == [ + {"intnum_null": -1}, + {"intnum_null": 80}, + ] + + +@pytest.mark.asyncio +async def test_limit_offset_values_list(db, intfields_data): + # Test limit/offset/ordering values_list + assert await IntFields.all().order_by("intnum").limit(10).values_list("intnum", flat=True) == [ + 10, + 13, + 16, + 19, + 22, + 25, + 28, + 31, + 34, + 37, + ] + + assert await IntFields.all().order_by("intnum").limit(10).offset(10).values_list( + "intnum", flat=True + ) == [40, 43, 46, 49, 52, 55, 58, 61, 64, 67] + + assert await IntFields.all().order_by("intnum").limit(10).offset(20).values_list( + "intnum", flat=True + ) == [70, 73, 76, 79, 82, 85, 88, 91, 94, 97] + + assert ( + await IntFields.all() + .order_by("intnum") + .limit(10) + .offset(30) + .values_list("intnum", flat=True) + == [] + ) + + assert await IntFields.all().order_by("-intnum").limit(10).values_list("intnum", flat=True) == [ + 97, + 94, + 91, + 88, + 85, + 82, + 79, + 76, + 73, + 70, + ] + + assert await IntFields.all().order_by("intnum").limit(10).filter(intnum__gte=40).values_list( + "intnum", flat=True + ) == [40, 43, 46, 49, 52, 55, 58, 61, 64, 67] + + +@pytest.mark.asyncio +async def test_limit_offset_values(db, intfields_data): + # Test limit/offset/ordering values + assert await IntFields.all().order_by("intnum").limit(5).values("intnum") == [ + {"intnum": 10}, + {"intnum": 13}, + {"intnum": 16}, + {"intnum": 19}, + {"intnum": 22}, + ] + + assert await IntFields.all().order_by("intnum").limit(5).offset(10).values("intnum") == [ + {"intnum": 40}, + {"intnum": 43}, + {"intnum": 46}, + {"intnum": 49}, + {"intnum": 52}, + ] + + assert await IntFields.all().order_by("intnum").limit(5).offset(30).values("intnum") == [] + + assert await IntFields.all().order_by("-intnum").limit(5).values("intnum") == [ + {"intnum": 97}, + {"intnum": 94}, + {"intnum": 91}, + {"intnum": 88}, + {"intnum": 85}, + ] + + assert await 
IntFields.all().order_by("intnum").limit(5).filter(intnum__gte=40).values( + "intnum" + ) == [ + {"intnum": 40}, + {"intnum": 43}, + {"intnum": 46}, + {"intnum": 49}, + {"intnum": 52}, + ] + + +@pytest.mark.asyncio +async def test_in_bulk(db, intfields_data): + id_list = [item.pk for item in await IntFields.all().only("id").limit(2)] + ret = await IntFields.in_bulk(id_list=id_list) + assert list(ret.keys()) == id_list + + +@pytest.mark.asyncio +async def test_first(db, intfields_data): + # Test first + assert (await IntFields.all().order_by("intnum").filter(intnum__gte=40).first()).intnum == 40 + assert (await IntFields.all().order_by("intnum").filter(intnum__gte=40).first().values())[ + "intnum" + ] == 40 + assert (await IntFields.all().order_by("intnum").filter(intnum__gte=40).first().values_list())[ + 1 + ] == 40 + + assert await IntFields.all().order_by("intnum").filter(intnum__gte=400).first() is None + assert await IntFields.all().order_by("intnum").filter(intnum__gte=400).first().values() is None + assert ( + await IntFields.all().order_by("intnum").filter(intnum__gte=400).first().values_list() + is None + ) + + +@pytest.mark.asyncio +async def test_last(db, intfields_data): + assert (await IntFields.all().order_by("intnum").filter(intnum__gte=40).last()).intnum == 97 + assert (await IntFields.all().order_by("intnum").filter(intnum__gte=40).last().values())[ + "intnum" + ] == 97 + assert (await IntFields.all().order_by("intnum").filter(intnum__gte=40).last().values_list())[ + 1 + ] == 97 + + assert await IntFields.all().order_by("intnum").filter(intnum__gte=400).last() is None + assert await IntFields.all().order_by("intnum").filter(intnum__gte=400).last().values() is None + assert ( + await IntFields.all().order_by("intnum").filter(intnum__gte=400).last().values_list() + is None + ) + assert (await IntFields.all().filter(intnum__gte=40).last()).intnum == 97 + + +@pytest.mark.asyncio +async def test_latest(db, intfields_data): + assert (await 
IntFields.all().latest("intnum")).intnum == 97 + assert (await IntFields.all().order_by("-intnum").first()).intnum == ( + await IntFields.all().latest("intnum") + ).intnum + assert (await IntFields.all().filter(intnum__gte=40).latest("intnum")).intnum == 97 + assert (await IntFields.all().filter(intnum__gte=40).latest("intnum").values())["intnum"] == 97 + assert (await IntFields.all().filter(intnum__gte=40).latest("intnum").values_list())[1] == 97 + + assert await IntFields.all().filter(intnum__gte=400).latest("intnum") is None + assert await IntFields.all().filter(intnum__gte=400).latest("intnum").values() is None + assert await IntFields.all().filter(intnum__gte=400).latest("intnum").values_list() is None + + with pytest.raises(FieldError): + await IntFields.all().latest() + + with pytest.raises(FieldError): + await IntFields.all().latest("some_unkown_field") + + +@pytest.mark.asyncio +async def test_earliest(db, intfields_data): + assert (await IntFields.all().earliest("intnum")).intnum == 10 + assert (await IntFields.all().order_by("intnum").first()).intnum == ( + await IntFields.all().earliest("intnum") + ).intnum + assert (await IntFields.all().filter(intnum__gte=40).earliest("intnum")).intnum == 40 + assert (await IntFields.all().filter(intnum__gte=40).earliest("intnum").values())[ + "intnum" + ] == 40 + assert (await IntFields.all().filter(intnum__gte=40).earliest("intnum").values_list())[1] == 40 + + assert await IntFields.all().filter(intnum__gte=400).earliest("intnum") is None + assert await IntFields.all().filter(intnum__gte=400).earliest("intnum").values() is None + assert await IntFields.all().filter(intnum__gte=400).earliest("intnum").values_list() is None + + with pytest.raises(FieldError): + await IntFields.all().earliest() + + with pytest.raises(FieldError): + await IntFields.all().earliest("some_unkown_field") + + +@pytest.mark.asyncio +async def test_get_or_none(db, intfields_data): + assert (await IntFields.all().get_or_none(intnum=40)).intnum 
== 40 + assert (await IntFields.all().get_or_none(intnum=40).values())["intnum"] == 40 + assert (await IntFields.all().get_or_none(intnum=40).values_list())[1] == 40 + + assert await IntFields.all().order_by("intnum").get_or_none(intnum__gte=400) is None + + assert await IntFields.all().order_by("intnum").get_or_none(intnum__gte=400).values() is None + + assert ( + await IntFields.all().order_by("intnum").get_or_none(intnum__gte=400).values_list() is None + ) + + with pytest.raises(MultipleObjectsReturned): + await IntFields.all().order_by("intnum").get_or_none(intnum__gte=40) + + with pytest.raises(MultipleObjectsReturned): + await IntFields.all().order_by("intnum").get_or_none(intnum__gte=40).values() + + with pytest.raises(MultipleObjectsReturned): + await IntFields.all().order_by("intnum").get_or_none(intnum__gte=40).values_list() + + +@pytest.mark.asyncio +async def test_get(db, intfields_data): + await IntFields.filter(intnum__gte=70).update(intnum_null=80) + + # Test get + assert (await IntFields.all().get(intnum=40)).intnum == 40 + assert (await IntFields.all().get(intnum=40).values())["intnum"] == 40 + assert (await IntFields.all().get(intnum=40).values_list())[1] == 40 + + assert (await IntFields.all().all().all().all().all().get(intnum=40)).intnum == 40 + assert (await IntFields.all().all().all().all().all().get(intnum=40).values())["intnum"] == 40 + assert (await IntFields.all().all().all().all().all().get(intnum=40).values_list())[1] == 40 + + assert (await IntFields.get(intnum=40)).intnum == 40 + assert (await IntFields.get(intnum=40).values())["intnum"] == 40 + assert (await IntFields.get(intnum=40).values_list())[1] == 40 + + with pytest.raises(DoesNotExist): + await IntFields.all().get(intnum=41) + + with pytest.raises(DoesNotExist): + await IntFields.all().get(intnum=41).values() + + with pytest.raises(DoesNotExist): + await IntFields.all().get(intnum=41).values_list() + + with pytest.raises(DoesNotExist): + await IntFields.get(intnum=41) + + with 
pytest.raises(DoesNotExist): + await IntFields.get(intnum=41).values() + + with pytest.raises(DoesNotExist): + await IntFields.get(intnum=41).values_list() + + with pytest.raises(MultipleObjectsReturned): + await IntFields.all().get(intnum_null=80) + + with pytest.raises(MultipleObjectsReturned): + await IntFields.all().get(intnum_null=80).values() + + with pytest.raises(MultipleObjectsReturned): + await IntFields.all().get(intnum_null=80).values_list() + + with pytest.raises(MultipleObjectsReturned): + await IntFields.get(intnum_null=80) + + with pytest.raises(MultipleObjectsReturned): + await IntFields.get(intnum_null=80).values() + + with pytest.raises(MultipleObjectsReturned): + await IntFields.get(intnum_null=80).values_list() + + +@pytest.mark.asyncio +async def test_delete(db, intfields_data): + # Test delete + await (await IntFields.get(intnum=40)).delete() + + with pytest.raises(DoesNotExist): + await IntFields.get(intnum=40) + + assert await IntFields.all().count() == 29 + + rows_affected = ( + await IntFields.all().order_by("intnum").limit(10).filter(intnum__gte=70).delete() + ) + assert rows_affected == 10 + + assert await IntFields.all().count() == 19 + + +@requireCapability(support_update_limit_order_by=True) +@pytest.mark.asyncio +async def test_delete_limit(db, intfields_data): + await IntFields.all().limit(1).delete() + assert await IntFields.all().count() == 29 + + +@requireCapability(support_update_limit_order_by=True) +@pytest.mark.asyncio +async def test_delete_limit_order_by(db, intfields_data): + await IntFields.all().limit(1).order_by("-id").delete() + assert await IntFields.all().count() == 29 + with pytest.raises(DoesNotExist): + await IntFields.get(intnum=97) + + +@pytest.mark.asyncio +async def test_async_iter(db, intfields_data): + counter = 0 + async for _ in IntFields.all(): + counter += 1 + + assert await IntFields.all().count() == counter + + +@pytest.mark.asyncio +async def test_update_basic(db): + obj0 = await 
IntFields.create(intnum=2147483647) + await IntFields.filter(id=obj0.id).update(intnum=2147483646) + obj = await IntFields.get(id=obj0.id) + assert obj.intnum == 2147483646 + assert obj.intnum_null is None + + +@pytest.mark.asyncio +async def test_update_f_expression(db): + obj0 = await IntFields.create(intnum=2147483647) + await IntFields.filter(id=obj0.id).update(intnum=F("intnum") - 1) + obj = await IntFields.get(id=obj0.id) + assert obj.intnum == 2147483646 + + +@pytest.mark.asyncio +async def test_update_badparam(db): + obj0 = await IntFields.create(intnum=2147483647) + with pytest.raises(FieldError, match="Unknown keyword argument"): + await IntFields.filter(id=obj0.id).update(badparam=1) + + +@pytest.mark.asyncio +async def test_update_pk(db): + obj0 = await IntFields.create(intnum=2147483647) + with pytest.raises(IntegrityError, match="is PK and can not be updated"): + await IntFields.filter(id=obj0.id).update(id=1) + + +@pytest.mark.asyncio +async def test_update_virtual(db): + tour = await Tournament.create(name="moo") + obj0 = await MinRelation.create(tournament=tour) + with pytest.raises(FieldError, match="is virtual and can not be updated"): + await MinRelation.filter(id=obj0.id).update(participants=[]) + + +@pytest.mark.asyncio +async def test_bad_ordering(db, intfields_data): + with pytest.raises(FieldError, match="Unknown field moo1fip for model IntFields"): + await IntFields.all().order_by("moo1fip") + + +@pytest.mark.asyncio +async def test_duplicate_values(db, intfields_data): + with pytest.raises(FieldError, match="Duplicate key intnum"): + await IntFields.all().values("intnum", "intnum") + + +@pytest.mark.asyncio +async def test_duplicate_values_list(db, intfields_data): + await IntFields.all().values_list("intnum", "intnum") + + +@pytest.mark.asyncio +async def test_duplicate_values_kw(db, intfields_data): + with pytest.raises(FieldError, match="Duplicate key intnum"): + await IntFields.all().values("intnum", intnum="intnum_null") + + 
+@pytest.mark.asyncio +async def test_duplicate_values_kw_badmap(db, intfields_data): + with pytest.raises(FieldError, match='Unknown field "intnum2" for model "IntFields"'): + await IntFields.all().values(intnum="intnum2") + + +@pytest.mark.asyncio +async def test_bad_values(db, intfields_data): + with pytest.raises(FieldError, match='Unknown field "int2num" for model "IntFields"'): + await IntFields.all().values("int2num") + + +@pytest.mark.asyncio +async def test_bad_values_list(db, intfields_data): + with pytest.raises(FieldError, match='Unknown field "int2num" for model "IntFields"'): + await IntFields.all().values_list("int2num") + + +@pytest.mark.asyncio +async def test_many_flat_values_list(db, intfields_data): + with pytest.raises(TypeError, match="You can flat value_list only if contains one field"): + await IntFields.all().values_list("intnum", "intnum_null", flat=True) + + +@pytest.mark.asyncio +async def test_all_flat_values_list(db, intfields_data): + with pytest.raises(TypeError, match="You can flat value_list only if contains one field"): + await IntFields.all().values_list(flat=True) + + +@pytest.mark.asyncio +async def test_all_values_list(db, intfields_data): + data = await IntFields.all().order_by("id").values_list() + assert data[2] == (intfields_data[2].id, 16, None) + + +@pytest.mark.asyncio +async def test_all_values(db, intfields_data): + data = await IntFields.all().order_by("id").values() + assert data[2] == {"id": intfields_data[2].id, "intnum": 16, "intnum_null": None} + + +@pytest.mark.asyncio +async def test_order_by_bad_value(db, intfields_data): + with pytest.raises(FieldError, match="Unknown field badid for model IntFields"): + await IntFields.all().order_by("badid").values_list() + + +@pytest.mark.asyncio +async def test_annotate_order_expression(db, intfields_data): + data = ( + await IntFields.annotate(idp=F("id") + 1).order_by("-idp").first().values_list("id", "idp") + ) + assert data[0] + 1 == data[1] + + +@pytest.mark.asyncio 
+async def test_annotate_order_rawsql(db, intfields_data): + qs = IntFields.annotate(idp=RawSQL("id+1")).order_by("-idp") + data = await qs.first().values_list("id", "idp") + assert data[0] + 1 == data[1] + + +@pytest.mark.asyncio +async def test_annotate_expression_filter(db, intfields_data): + count = await IntFields.annotate(intnum1=F("intnum") + 1).filter(intnum1__gt=30).count() + assert count == 23 + + +@pytest.mark.asyncio +async def test_get_raw_sql(db, intfields_data): + sql = IntFields.all().sql() + assert "SELECT" in sql and "FROM" in sql + + +@requireCapability(support_index_hint=True) +@pytest.mark.asyncio +async def test_force_index(db, intfields_data): + sql = IntFields.filter(pk=1).only("id").force_index("index_name").sql() + assert sql == "SELECT `id` `id` FROM `intfields` FORCE INDEX (`index_name`) WHERE `id`=%s" + + sql_again = IntFields.filter(pk=1).only("id").force_index("index_name").sql() + assert sql_again == "SELECT `id` `id` FROM `intfields` FORCE INDEX (`index_name`) WHERE `id`=%s" + + +@requireCapability(support_index_hint=True) +@pytest.mark.asyncio +async def test_force_index_available_in_more_query(db, intfields_data): + sql_ValuesQuery = IntFields.filter(pk=1).force_index("index_name").values("id").sql() + assert ( + sql_ValuesQuery + == "SELECT `id` `id` FROM `intfields` FORCE INDEX (`index_name`) WHERE `id`=%s" + ) + + sql_ValuesListQuery = IntFields.filter(pk=1).force_index("index_name").values_list("id").sql() + assert ( + sql_ValuesListQuery + == "SELECT `id` `0` FROM `intfields` FORCE INDEX (`index_name`) WHERE `id`=%s" + ) + + sql_CountQuery = IntFields.filter(pk=1).force_index("index_name").count().sql() + assert ( + sql_CountQuery + == "SELECT COUNT(*) FROM `intfields` FORCE INDEX (`index_name`) WHERE `id`=%s" + ) + + sql_ExistsQuery = IntFields.filter(pk=1).force_index("index_name").exists().sql() + assert ( + sql_ExistsQuery + == "SELECT 1 FROM `intfields` FORCE INDEX (`index_name`) WHERE `id`=%s LIMIT %s" + ) + + 
+@requireCapability(support_index_hint=True) +@pytest.mark.asyncio +async def test_use_index(db, intfields_data): + sql = IntFields.filter(pk=1).only("id").use_index("index_name").sql() + assert sql == "SELECT `id` `id` FROM `intfields` USE INDEX (`index_name`) WHERE `id`=%s" + + sql_again = IntFields.filter(pk=1).only("id").use_index("index_name").sql() + assert sql_again == "SELECT `id` `id` FROM `intfields` USE INDEX (`index_name`) WHERE `id`=%s" + + +@requireCapability(support_index_hint=True) +@pytest.mark.asyncio +async def test_use_index_available_in_more_query(db, intfields_data): + sql_ValuesQuery = IntFields.filter(pk=1).use_index("index_name").values("id").sql() + assert ( + sql_ValuesQuery + == "SELECT `id` `id` FROM `intfields` USE INDEX (`index_name`) WHERE `id`=%s" + ) + + sql_ValuesListQuery = IntFields.filter(pk=1).use_index("index_name").values_list("id").sql() + assert ( + sql_ValuesListQuery + == "SELECT `id` `0` FROM `intfields` USE INDEX (`index_name`) WHERE `id`=%s" + ) + + sql_CountQuery = IntFields.filter(pk=1).use_index("index_name").count().sql() + assert ( + sql_CountQuery == "SELECT COUNT(*) FROM `intfields` USE INDEX (`index_name`) WHERE `id`=%s" + ) + + sql_ExistsQuery = IntFields.filter(pk=1).use_index("index_name").exists().sql() + assert ( + sql_ExistsQuery + == "SELECT 1 FROM `intfields` USE INDEX (`index_name`) WHERE `id`=%s LIMIT %s" + ) + + +@requireCapability(support_for_update=True) +@pytest.mark.asyncio +async def test_select_for_update(db, intfields_data): + sql1 = IntFields.filter(pk=1).only("id").select_for_update().sql() + sql2 = IntFields.filter(pk=1).only("id").select_for_update(nowait=True).sql() + sql3 = IntFields.filter(pk=1).only("id").select_for_update(skip_locked=True).sql() + sql4 = IntFields.filter(pk=1).only("id").select_for_update(of=("intfields",)).sql() + sql5 = IntFields.filter(pk=1).only("id").select_for_update(no_key=True).sql() + + db_conn = connections.get("models") + dialect = 
db_conn.schema_generator.DIALECT + if dialect == "postgres": + if isinstance(db_conn, PsycopgClient): + assert sql1 == 'SELECT "id" "id" FROM "intfields" WHERE "id"=%s FOR UPDATE' + assert sql2 == 'SELECT "id" "id" FROM "intfields" WHERE "id"=%s FOR UPDATE NOWAIT' + assert sql3 == 'SELECT "id" "id" FROM "intfields" WHERE "id"=%s FOR UPDATE SKIP LOCKED' + assert ( + sql4 == 'SELECT "id" "id" FROM "intfields" WHERE "id"=%s FOR UPDATE OF "intfields"' ) - self.assertEqual( - sql5, - "SELECT `id` `id` FROM `intfields` WHERE `id`=%s FOR UPDATE", + assert sql5 == 'SELECT "id" "id" FROM "intfields" WHERE "id"=%s FOR NO KEY UPDATE' + else: + assert sql1 == 'SELECT "id" "id" FROM "intfields" WHERE "id"=$1 FOR UPDATE' + assert sql2 == 'SELECT "id" "id" FROM "intfields" WHERE "id"=$1 FOR UPDATE NOWAIT' + assert sql3 == 'SELECT "id" "id" FROM "intfields" WHERE "id"=$1 FOR UPDATE SKIP LOCKED' + assert ( + sql4 == 'SELECT "id" "id" FROM "intfields" WHERE "id"=$1 FOR UPDATE OF "intfields"' ) + assert sql5 == 'SELECT "id" "id" FROM "intfields" WHERE "id"=$1 FOR NO KEY UPDATE' + elif dialect == "mysql": + assert sql1 == "SELECT `id` `id` FROM `intfields` WHERE `id`=%s FOR UPDATE" + assert sql2 == "SELECT `id` `id` FROM `intfields` WHERE `id`=%s FOR UPDATE NOWAIT" + assert sql3 == "SELECT `id` `id` FROM `intfields` WHERE `id`=%s FOR UPDATE SKIP LOCKED" + assert sql4 == "SELECT `id` `id` FROM `intfields` WHERE `id`=%s FOR UPDATE OF `intfields`" + assert sql5 == "SELECT `id` `id` FROM `intfields` WHERE `id`=%s FOR UPDATE" + + +@pytest.mark.asyncio +async def test_select_related(db): + tournament = await Tournament.create(name="1") + reporter = await Reporter.create(name="Reporter") + event = await Event.create(name="1", tournament=tournament, reporter=reporter) + event = await Event.all().select_related("tournament", "reporter").get(pk=event.pk) + assert event.tournament.pk == tournament.pk + assert event.reporter.pk == reporter.pk + + +@pytest.mark.asyncio +async def 
test_select_related_with_two_same_models(db): + parent_node = await Node.create(name="1") + child_node = await Node.create(name="2") + tree = await Tree.create(parent=parent_node, child=child_node) + tree = await Tree.all().select_related("parent", "child").get(pk=tree.pk) + assert tree.parent.pk == parent_node.pk + assert tree.parent.name == parent_node.name + assert tree.child.pk == child_node.pk + assert tree.child.name == child_node.name + + +@requireCapability(dialect="postgres") +@pytest.mark.asyncio +async def test_postgres_search(db): + name = "hello world" + await Tournament.create(name=name) + ret = await Tournament.filter(name__search="hello").first() + assert ret.name == name + + +@pytest.mark.asyncio +async def test_subquery_select(db): + t1 = await Tournament.create(name="1") + ret = ( + await Tournament.filter(pk=t1.pk) + .annotate(ids=Subquery(Tournament.filter(pk=t1.pk).values("id"))) + .values("ids", "id") + ) + assert ret == [{"id": t1.pk, "ids": t1.pk}] + + +@pytest.mark.asyncio +async def test_subquery_filter(db): + t1 = await Tournament.create(name="1") + ret = await Tournament.filter(pk=Subquery(Tournament.filter(pk=t1.pk).values("id"))).first() + assert ret == t1 + + +@pytest.mark.asyncio +async def test_raw_sql_count(db): + t1 = await Tournament.create(name="1") + ret = await Tournament.filter(pk=t1.pk).annotate(count=RawSQL("count(*)")).values("count") + assert ret == [{"count": 1}] + + +@requireCapability(dialect=NotEQ("mssql")) +@pytest.mark.asyncio +async def test_raw_sql_select(db): + t1 = await Tournament.create(id=1, name="1") + ret = ( + await Tournament.filter(pk=t1.pk).annotate(idp=RawSQL("id + 1")).filter(idp=2).values("idp") + ) + assert ret == [{"idp": 2}] + + +@pytest.mark.asyncio +async def test_raw_sql_filter(db): + ret = await Tournament.filter(pk=RawSQL("id + 1")) + assert ret == [] + + +@pytest.mark.asyncio +async def test_annotation_field_priorior_to_model_field(db): + # Sometimes, field name in annotates also exist in 
model field sets + # and may need lift the former's priority in select query construction. + t1 = await Tournament.create(name="1") + ret = await Tournament.filter(pk=t1.pk).annotate(id=RawSQL("id + 1")).values("id") + assert ret == [{"id": t1.pk + 1}] + + +@pytest.mark.asyncio +async def test_f_annotation_referenced_in_annotation(db): + instance = await IntFields.create(intnum=1) + + events = ( + await IntFields.filter(id=instance.id) + .annotate(intnum_plus_1=F("intnum") + 1) + .annotate(intnum_plus_2=F("intnum_plus_1") + 1) + ) + assert len(events) == 1 + assert events[0].intnum_plus_1 == 2 + assert events[0].intnum_plus_2 == 3 + + # in a single annotate call + events = await IntFields.filter(id=instance.id).annotate( + intnum_plus_1=F("intnum") + 1, intnum_plus_2=F("intnum_plus_1") + 1 + ) + assert len(events) == 1 + assert events[0].intnum_plus_1 == 2 + assert events[0].intnum_plus_2 == 3 + + +@pytest.mark.asyncio +async def test_rawsql_annotation_referenced_in_annotation(db): + instance = await IntFields.create(intnum=1) + + events = ( + await IntFields.filter(id=instance.id) + .annotate(ten=RawSQL("20 / 2")) + .annotate(ten_plus_1=F("ten") + 1) + ) + + assert len(events) == 1 + assert events[0].ten == 10 + assert events[0].ten_plus_1 == 11 + + +@pytest.mark.asyncio +async def test_joins_in_arithmetic_expressions(db): + author = await Author.create(name="1") + await Book.create(name="1", author=author, rating=1) + await Book.create(name="2", author=author, rating=5) + + ret = await Author.annotate(rating=Avg(F("books__rating") + 1)) + assert len(ret) == 1 + assert ret[0].rating == 4.0 + + ret = await Author.annotate(rating=Avg(F("books__rating") * 2 - F("books__rating"))) + assert len(ret) == 1 + assert ret[0].rating == 3.0 + + +@pytest.mark.asyncio +async def test_annotations_in_flat_values_list(db): + author1 = await Author.create(name="1") + author2 = await Author.create(name="2") + author3 = await Author.create(name="3") + await Book.create(name="1", 
author=author1, rating=1) + await Book.create(name="2", author=author2, rating=3) + await Book.create(name="3", author=author3, rating=5) + + subquery = Author.annotate(rating=Avg("books__rating")).filter(rating__gte=3) + + subquery_ret = await subquery.order_by("id").values_list("id", flat=True) + assert len(subquery_ret) == 2 + assert subquery_ret[0] == author2.pk + assert subquery_ret[1] == author3.pk + + ret = await Author.filter(id__in=Subquery(subquery.values_list("id", flat=True))).order_by("id") + assert ret[0] == author2 + assert ret[1] == author3 - async def test_select_related(self): - tournament = await Tournament.create(name="1") - reporter = await Reporter.create(name="Reporter") - event = await Event.create(name="1", tournament=tournament, reporter=reporter) - event = await Event.all().select_related("tournament", "reporter").get(pk=event.pk) - self.assertEqual(event.tournament.pk, tournament.pk) - self.assertEqual(event.reporter.pk, reporter.pk) - - async def test_select_related_with_two_same_models(self): - parent_node = await Node.create(name="1") - child_node = await Node.create(name="2") - tree = await Tree.create(parent=parent_node, child=child_node) - tree = await Tree.all().select_related("parent", "child").get(pk=tree.pk) - self.assertEqual(tree.parent.pk, parent_node.pk) - self.assertEqual(tree.parent.name, parent_node.name) - self.assertEqual(tree.child.pk, child_node.pk) - self.assertEqual(tree.child.name, child_node.name) - - @test.requireCapability(dialect="postgres") - async def test_postgres_search(self): - name = "hello world" - await Tournament.create(name=name) - ret = await Tournament.filter(name__search="hello").first() - self.assertEqual(ret.name, name) - - async def test_subquery_select(self): - t1 = await Tournament.create(name="1") - ret = ( - await Tournament.filter(pk=t1.pk) - .annotate(ids=Subquery(Tournament.filter(pk=t1.pk).values("id"))) - .values("ids", "id") - ) - self.assertEqual(ret, [{"id": t1.pk, "ids": t1.pk}]) - 
- async def test_subquery_filter(self): - t1 = await Tournament.create(name="1") - ret = await Tournament.filter(pk=Subquery(Tournament.filter(pk=t1.pk).values("id"))).first() - self.assertEqual(ret, t1) - - async def test_raw_sql_count(self): - t1 = await Tournament.create(name="1") - ret = await Tournament.filter(pk=t1.pk).annotate(count=RawSQL("count(*)")).values("count") - self.assertEqual(ret, [{"count": 1}]) - - @test.requireCapability(dialect=NotEQ("mssql")) - async def test_raw_sql_select(self): - t1 = await Tournament.create(id=1, name="1") - ret = ( - await Tournament.filter(pk=t1.pk) - .annotate(idp=RawSQL("id + 1")) - .filter(idp=2) - .values("idp") - ) - self.assertEqual(ret, [{"idp": 2}]) - - async def test_raw_sql_filter(self): - ret = await Tournament.filter(pk=RawSQL("id + 1")) - self.assertEqual(ret, []) - - async def test_annotation_field_priorior_to_model_field(self): - # Sometimes, field name in annotates also exist in model field sets - # and may need lift the former's priority in select query construction. 
- t1 = await Tournament.create(name="1") - ret = await Tournament.filter(pk=t1.pk).annotate(id=RawSQL("id + 1")).values("id") - self.assertEqual(ret, [{"id": t1.pk + 1}]) - - async def test_f_annotation_referenced_in_annotation(self): - instance = await IntFields.create(intnum=1) - - events = ( - await IntFields.filter(id=instance.id) - .annotate(intnum_plus_1=F("intnum") + 1) - .annotate(intnum_plus_2=F("intnum_plus_1") + 1) - ) - self.assertEqual(len(events), 1) - self.assertEqual(events[0].intnum_plus_1, 2) - self.assertEqual(events[0].intnum_plus_2, 3) - - # in a single annotate call - events = await IntFields.filter(id=instance.id).annotate( - intnum_plus_1=F("intnum") + 1, intnum_plus_2=F("intnum_plus_1") + 1 - ) - self.assertEqual(len(events), 1) - self.assertEqual(events[0].intnum_plus_1, 2) - self.assertEqual(events[0].intnum_plus_2, 3) - - async def test_rawsql_annotation_referenced_in_annotation(self): - instance = await IntFields.create(intnum=1) - - events = ( - await IntFields.filter(id=instance.id) - .annotate(ten=RawSQL("20 / 2")) - .annotate(ten_plus_1=F("ten") + 1) - ) - - self.assertEqual(len(events), 1) - self.assertEqual(events[0].ten, 10) - self.assertEqual(events[0].ten_plus_1, 11) - - async def test_joins_in_arithmetic_expressions(self): - author = await Author.create(name="1") - await Book.create(name="1", author=author, rating=1) - await Book.create(name="2", author=author, rating=5) - - ret = await Author.annotate(rating=Avg(F("books__rating") + 1)) - self.assertEqual(len(ret), 1) - self.assertEqual(ret[0].rating, 4.0) - - ret = await Author.annotate(rating=Avg(F("books__rating") * 2 - F("books__rating"))) - self.assertEqual(len(ret), 1) - self.assertEqual(ret[0].rating, 3.0) - - async def test_annotations_in_flat_values_list(self): - author1 = await Author.create(name="1") - author2 = await Author.create(name="2") - author3 = await Author.create(name="3") - await Book.create(name="1", author=author1, rating=1) - await 
Book.create(name="2", author=author2, rating=3) - await Book.create(name="3", author=author3, rating=5) - - subquery = Author.annotate(rating=Avg("books__rating")).filter(rating__gte=3) - - subquery_ret = await subquery.order_by("id").values_list("id", flat=True) - self.assertEqual(len(subquery_ret), 2) - self.assertEqual(subquery_ret[0], author2.pk) - self.assertEqual(subquery_ret[1], author3.pk) - - ret = await Author.filter(id__in=Subquery(subquery.values_list("id", flat=True))).order_by( - "id" - ) - self.assertEqual(ret[0], author2) - self.assertEqual(ret[1], author3) - - -class TestNotExist(test.TestCase): - exp_cls: type[NotExistOrMultiple] = DoesNotExist - @test.requireCapability(dialect="sqlite") - def test_does_not_exist(self): - assert str(self.exp_cls("old format")) == "old format" - assert str(self.exp_cls(Tournament)) == self.exp_cls.TEMPLATE.format(Tournament.__name__) +# Tests for exception classes (no database needed, pure Python tests) +def test_does_not_exist(): + exp_cls: type[NotExistOrMultiple] = DoesNotExist + assert str(exp_cls("old format")) == "old format" + assert str(exp_cls(Tournament)) == exp_cls.TEMPLATE.format(Tournament.__name__) -class TestMultiple(TestNotExist): - exp_cls = MultipleObjectsReturned +def test_multiple_objects_returned(): + exp_cls: type[NotExistOrMultiple] = MultipleObjectsReturned + assert str(exp_cls("old format")) == "old format" + assert str(exp_cls(Tournament)) == exp_cls.TEMPLATE.format(Tournament.__name__) diff --git a/tests/test_queryset_reuse.py b/tests/test_queryset_reuse.py index 7d7adf95e..6327f902c 100644 --- a/tests/test_queryset_reuse.py +++ b/tests/test_queryset_reuse.py @@ -1,3 +1,5 @@ +import pytest + from tests.testmodels import Event, Tournament from tortoise.contrib import test from tortoise.contrib.test.condition import NotEQ @@ -5,81 +7,89 @@ from tortoise.functions import Length -class TestQueryReuse(test.TestCase): - async def test_annotations(self): - a = await Tournament.create(name="A") - 
- base_query = Tournament.annotate(id_plus_one=F("id") + 1) - query1 = base_query.annotate(id_plus_two=F("id") + 2) - query2 = base_query.annotate(id_plus_three=F("id") + 3) - res = await query1.first() - self.assertEqual(res.id_plus_one, a.id + 1) - self.assertEqual(res.id_plus_two, a.id + 2) - with self.assertRaises(AttributeError): - getattr(res, "id_plus_three") - - res = await query2.first() - self.assertEqual(res.id_plus_one, a.id + 1) - self.assertEqual(res.id_plus_three, a.id + 3) - with self.assertRaises(AttributeError): - getattr(res, "id_plus_two") - - res = await query1.first() - with self.assertRaises(AttributeError): - getattr(res, "id_plus_three") - - async def test_filters(self): - a = await Tournament.create(name="A") - b = await Tournament.create(name="B") - await Tournament.create(name="C") - - base_query = Tournament.exclude(name="C") - tournaments = await base_query - self.assertSetEqual(set(tournaments), {a, b}) - - tournaments = await base_query.exclude(name="A") - self.assertSetEqual(set(tournaments), {b}) - - tournaments = await base_query.exclude(name="B") - self.assertSetEqual(set(tournaments), {a}) - - async def test_joins(self): - tournament_a = await Tournament.create(name="A") - tournament_b = await Tournament.create(name="B") - tournament_c = await Tournament.create(name="C") - event_a = await Event.create(name="A", tournament=tournament_a) - event_b = await Event.create(name="B", tournament=tournament_b) - await Event.create(name="C", tournament=tournament_c) - - base_query = Event.exclude(tournament__name="C") - events = await base_query - self.assertSetEqual(set(events), {event_a, event_b}) - - events = await base_query.exclude(name="A") - self.assertSetEqual(set(events), {event_b}) - - events = await base_query.exclude(name="B") - self.assertSetEqual(set(events), {event_a}) - - async def test_order_by(self): - a = await Tournament.create(name="A") - b = await Tournament.create(name="B") - - base_query = 
Tournament.all().order_by("name") - tournaments = await base_query - self.assertEqual(tournaments, [a, b]) - - tournaments = await base_query.order_by("-name") - self.assertEqual(tournaments, [b, a]) - - @test.requireCapability(dialect=NotEQ("mssql")) - async def test_values_with_annotations(self): - await Tournament.create(name="Championship") - await Tournament.create(name="Super Bowl") - - base_query = Tournament.annotate(name_length=Length("name")) - tournaments = await base_query.values_list("name") - self.assertListSortEqual(tournaments, [("Championship",), ("Super Bowl",)]) - - tournaments = await base_query.values_list("name_length") - self.assertListSortEqual(tournaments, [(10,), (12,)]) +@pytest.mark.asyncio +async def test_annotations(db): + a = await Tournament.create(name="A") + + base_query = Tournament.annotate(id_plus_one=F("id") + 1) + query1 = base_query.annotate(id_plus_two=F("id") + 2) + query2 = base_query.annotate(id_plus_three=F("id") + 3) + res = await query1.first() + assert res.id_plus_one == a.id + 1 + assert res.id_plus_two == a.id + 2 + with pytest.raises(AttributeError): + getattr(res, "id_plus_three") + + res = await query2.first() + assert res.id_plus_one == a.id + 1 + assert res.id_plus_three == a.id + 3 + with pytest.raises(AttributeError): + getattr(res, "id_plus_two") + + res = await query1.first() + with pytest.raises(AttributeError): + getattr(res, "id_plus_three") + + +@pytest.mark.asyncio +async def test_filters(db): + a = await Tournament.create(name="A") + b = await Tournament.create(name="B") + await Tournament.create(name="C") + + base_query = Tournament.exclude(name="C") + tournaments = await base_query + assert set(tournaments) == {a, b} + + tournaments = await base_query.exclude(name="A") + assert set(tournaments) == {b} + + tournaments = await base_query.exclude(name="B") + assert set(tournaments) == {a} + + +@pytest.mark.asyncio +async def test_joins(db): + tournament_a = await Tournament.create(name="A") + 
tournament_b = await Tournament.create(name="B") + tournament_c = await Tournament.create(name="C") + event_a = await Event.create(name="A", tournament=tournament_a) + event_b = await Event.create(name="B", tournament=tournament_b) + await Event.create(name="C", tournament=tournament_c) + + base_query = Event.exclude(tournament__name="C") + events = await base_query + assert set(events) == {event_a, event_b} + + events = await base_query.exclude(name="A") + assert set(events) == {event_b} + + events = await base_query.exclude(name="B") + assert set(events) == {event_a} + + +@pytest.mark.asyncio +async def test_order_by(db): + a = await Tournament.create(name="A") + b = await Tournament.create(name="B") + + base_query = Tournament.all().order_by("name") + tournaments = await base_query + assert tournaments == [a, b] + + tournaments = await base_query.order_by("-name") + assert tournaments == [b, a] + + +@test.requireCapability(dialect=NotEQ("mssql")) +@pytest.mark.asyncio +async def test_values_with_annotations(db): + await Tournament.create(name="Championship") + await Tournament.create(name="Super Bowl") + + base_query = Tournament.annotate(name_length=Length("name")) + tournaments = await base_query.values_list("name") + assert sorted(tournaments) == sorted([("Championship",), ("Super Bowl",)]) + + tournaments = await base_query.values_list("name_length") + assert sorted(tournaments) == sorted([(10,), (12,)]) diff --git a/tests/test_relations.py b/tests/test_relations.py index bdc80be5d..6afe270ec 100644 --- a/tests/test_relations.py +++ b/tests/test_relations.py @@ -1,4 +1,9 @@ +import re import subprocess # nosec +import sys + +import pytest +import pytest_asyncio from tests.testmodels import ( Address, @@ -18,178 +23,204 @@ Tournament, UUIDFkRelatedNullModel, ) -from tortoise.contrib import test +from tortoise.contrib.test import requireCapability from tortoise.contrib.test.condition import NotIn from tortoise.exceptions import FieldError, NoValuesFetched, 
OperationalError from tortoise.functions import Count, Trim +# ============================================================================= +# TestRelations - uses db fixture (transaction rollback) +# ============================================================================= + + +@pytest.mark.asyncio +async def test_relations(db): + tournament = Tournament(name="New Tournament") + await tournament.save() + await Event(name="Without participants", tournament_id=tournament.id).save() + event = Event(name="Test", tournament_id=tournament.id) + await event.save() + participants = [] + for i in range(2): + team = Team(name=f"Team {(i + 1)}") + await team.save() + participants.append(team) + await event.participants.add(participants[0], participants[1]) + await event.participants.add(participants[0], participants[1]) + + with pytest.raises(NoValuesFetched): + [team.id for team in event.participants] # pylint: disable=W0104 + + teamids = [] + async for team in event.participants: + teamids.append(team.id) + assert set(teamids) == {participants[0].id, participants[1].id} + teamids = [team.id async for team in event.participants] + assert set(teamids) == {participants[0].id, participants[1].id} + + assert {team.id for team in event.participants} == {participants[0].id, participants[1].id} + + assert event.participants[0].id in {participants[0].id, participants[1].id} + + selected_events = await Event.filter(participants=participants[0].id).prefetch_related( + "participants", "tournament" + ) + assert len(selected_events) == 1 + assert selected_events[0].tournament.id == tournament.id + assert len(selected_events[0].participants) == 2 + await participants[0].fetch_related("events") + assert participants[0].events[0] == event + + await Team.fetch_for_list(participants, "events") + + await Team.filter(events__tournament__id=tournament.id) + + await Event.filter(tournament=tournament) + + await Tournament.filter(events__name__in=["Test", "Prod"]).distinct() -class 
TestRelations(test.TestCase): - async def test_relations(self): - tournament = Tournament(name="New Tournament") - await tournament.save() - await Event(name="Without participants", tournament_id=tournament.id).save() - event = Event(name="Test", tournament_id=tournament.id) - await event.save() - participants = [] - for i in range(2): - team = Team(name=f"Team {(i + 1)}") - await team.save() - participants.append(team) - await event.participants.add(participants[0], participants[1]) - await event.participants.add(participants[0], participants[1]) - - with self.assertRaises(NoValuesFetched): - [team.id for team in event.participants] # pylint: disable=W0104 - - teamids = [] - async for team in event.participants: - teamids.append(team.id) - self.assertEqual(set(teamids), {participants[0].id, participants[1].id}) - teamids = [team.id async for team in event.participants] - self.assertEqual(set(teamids), {participants[0].id, participants[1].id}) - - self.assertEqual( - {team.id for team in event.participants}, {participants[0].id, participants[1].id} - ) - - self.assertIn(event.participants[0].id, {participants[0].id, participants[1].id}) - - selected_events = await Event.filter(participants=participants[0].id).prefetch_related( - "participants", "tournament" - ) - self.assertEqual(len(selected_events), 1) - self.assertEqual(selected_events[0].tournament.id, tournament.id) - self.assertEqual(len(selected_events[0].participants), 2) - await participants[0].fetch_related("events") - self.assertEqual(participants[0].events[0], event) - - await Team.fetch_for_list(participants, "events") - - await Team.filter(events__tournament__id=tournament.id) - - await Event.filter(tournament=tournament) - - await Tournament.filter(events__name__in=["Test", "Prod"]).distinct() - - result = await Event.filter(pk=event.pk).values( - "event_id", "name", tournament="tournament__name" - ) - self.assertEqual(result[0]["tournament"], tournament.name) - - result = await 
Event.filter(pk=event.pk).values_list("event_id", "participants__name") - self.assertEqual(len(result), 2) - - async def test_reset_queryset_on_query(self): - tournament = await Tournament.create(name="New Tournament") - event = await Event.create(name="Test", tournament_id=tournament.id) - participants = [] - for i in range(2): - team = await Team.create(name=f"Team {(i + 1)}") - participants.append(team) - await event.participants.add(*participants) - queryset = Event.all().annotate(count=Count("participants")) - await queryset.first() - await queryset.filter(name="Test").first() - - async def test_bool_for_relation_new_object(self): - tournament = await Tournament.create(name="New Tournament") - - with self.assertRaises(NoValuesFetched): - bool(tournament.events) - - async def test_bool_for_relation_old_object(self): - await Tournament.create(name="New Tournament") - tournament = await Tournament.first() - - with self.assertRaises(NoValuesFetched): - bool(tournament.events) - - async def test_bool_for_relation_fetched_false(self): - tournament = await Tournament.create(name="New Tournament") - await tournament.fetch_related("events") - - self.assertFalse(bool(tournament.events)) - - async def test_bool_for_relation_fetched_true(self): - tournament = await Tournament.create(name="New Tournament") - await Event.create(name="Test", tournament_id=tournament.id) - await tournament.fetch_related("events") - - self.assertTrue(bool(tournament.events)) - - async def test_m2m_add(self): - tournament = await Tournament.create(name="tournament") - event = await Event.create(name="First", tournament=tournament) - team = await Team.create(name="1") - team_second = await Team.create(name="2") - await event.participants.add(team, team_second) - fetched_event = await Event.first().prefetch_related("participants") - self.assertEqual(len(fetched_event.participants), 2) - - async def test_m2m_add_already_added(self): - tournament = await Tournament.create(name="tournament") - event 
= await Event.create(name="First", tournament=tournament) - team = await Team.create(name="1") - team_second = await Team.create(name="2") - await event.participants.add(team, team_second) - await event.participants.add(team, team_second) - fetched_event = await Event.first().prefetch_related("participants") - self.assertEqual(len(fetched_event.participants), 2) - - async def test_m2m_clear(self): - tournament = await Tournament.create(name="tournament") - event = await Event.create(name="First", tournament=tournament) - team = await Team.create(name="1") - team_second = await Team.create(name="2") - await event.participants.add(team, team_second) - await event.participants.clear() - fetched_event = await Event.first().prefetch_related("participants") - self.assertEqual(len(fetched_event.participants), 0) - - async def test_m2m_remove(self): - tournament = await Tournament.create(name="tournament") - event = await Event.create(name="First", tournament=tournament) - team = await Team.create(name="1") - team_second = await Team.create(name="2") - await event.participants.add(team, team_second) - await event.participants.remove(team) - fetched_event = await Event.first().prefetch_related("participants") - self.assertEqual(len(fetched_event.participants), 1) - - async def test_o2o_lazy(self): - tournament = await Tournament.create(name="tournament") - event = await Event.create(name="First", tournament=tournament) - await Address.create(city="Santa Monica", street="Ocean", event=event) - - fetched_address = await event.address - self.assertEqual(fetched_address.city, "Santa Monica") - - async def test_m2m_remove_two(self): - tournament = await Tournament.create(name="tournament") - event = await Event.create(name="First", tournament=tournament) - team = await Team.create(name="1") - team_second = await Team.create(name="2") - await event.participants.add(team, team_second) - await event.participants.remove(team, team_second) - fetched_event = await 
Event.first().prefetch_related("participants") - self.assertEqual(len(fetched_event.participants), 0) - - async def test_self_ref(self): - root = await Employee.create(name="Root") - loose = await Employee.create(name="Loose") - _1 = await Employee.create(name="1. First H1", manager=root) - _2 = await Employee.create(name="2. Second H1", manager=root) - _1_1 = await Employee.create(name="1.1. First H2", manager=_1) - _1_1_1 = await Employee.create(name="1.1.1. First H3", manager=_1_1) - _2_1 = await Employee.create(name="2.1. Second H2", manager=_2) - _2_2 = await Employee.create(name="2.2. Third H2", manager=_2) - - await _1.talks_to.add(_2, _1_1_1, loose) - await _2_1.gets_talked_to.add(_2_2, _1_1, loose) - - LOOSE_TEXT = "Loose (to: 2.1. Second H2) (from: 1. First H1)" - ROOT_TEXT = """Root (to: ) (from: ) + result = await Event.filter(pk=event.pk).values( + "event_id", "name", tournament="tournament__name" + ) + assert result[0]["tournament"] == tournament.name + + result = await Event.filter(pk=event.pk).values_list("event_id", "participants__name") + assert len(result) == 2 + + +@pytest.mark.asyncio +async def test_reset_queryset_on_query(db): + tournament = await Tournament.create(name="New Tournament") + event = await Event.create(name="Test", tournament_id=tournament.id) + participants = [] + for i in range(2): + team = await Team.create(name=f"Team {(i + 1)}") + participants.append(team) + await event.participants.add(*participants) + queryset = Event.all().annotate(count=Count("participants")) + await queryset.first() + await queryset.filter(name="Test").first() + + +@pytest.mark.asyncio +async def test_bool_for_relation_new_object(db): + tournament = await Tournament.create(name="New Tournament") + + with pytest.raises(NoValuesFetched): + bool(tournament.events) + + +@pytest.mark.asyncio +async def test_bool_for_relation_old_object(db): + await Tournament.create(name="New Tournament") + tournament = await Tournament.first() + + with 
pytest.raises(NoValuesFetched): + bool(tournament.events) + + +@pytest.mark.asyncio +async def test_bool_for_relation_fetched_false(db): + tournament = await Tournament.create(name="New Tournament") + await tournament.fetch_related("events") + + assert not bool(tournament.events) + + +@pytest.mark.asyncio +async def test_bool_for_relation_fetched_true(db): + tournament = await Tournament.create(name="New Tournament") + await Event.create(name="Test", tournament_id=tournament.id) + await tournament.fetch_related("events") + + assert bool(tournament.events) + + +@pytest.mark.asyncio +async def test_m2m_add(db): + tournament = await Tournament.create(name="tournament") + event = await Event.create(name="First", tournament=tournament) + team = await Team.create(name="1") + team_second = await Team.create(name="2") + await event.participants.add(team, team_second) + fetched_event = await Event.first().prefetch_related("participants") + assert len(fetched_event.participants) == 2 + + +@pytest.mark.asyncio +async def test_m2m_add_already_added(db): + tournament = await Tournament.create(name="tournament") + event = await Event.create(name="First", tournament=tournament) + team = await Team.create(name="1") + team_second = await Team.create(name="2") + await event.participants.add(team, team_second) + await event.participants.add(team, team_second) + fetched_event = await Event.first().prefetch_related("participants") + assert len(fetched_event.participants) == 2 + + +@pytest.mark.asyncio +async def test_m2m_clear(db): + tournament = await Tournament.create(name="tournament") + event = await Event.create(name="First", tournament=tournament) + team = await Team.create(name="1") + team_second = await Team.create(name="2") + await event.participants.add(team, team_second) + await event.participants.clear() + fetched_event = await Event.first().prefetch_related("participants") + assert len(fetched_event.participants) == 0 + + +@pytest.mark.asyncio +async def 
test_m2m_remove(db): + tournament = await Tournament.create(name="tournament") + event = await Event.create(name="First", tournament=tournament) + team = await Team.create(name="1") + team_second = await Team.create(name="2") + await event.participants.add(team, team_second) + await event.participants.remove(team) + fetched_event = await Event.first().prefetch_related("participants") + assert len(fetched_event.participants) == 1 + + +@pytest.mark.asyncio +async def test_o2o_lazy(db): + tournament = await Tournament.create(name="tournament") + event = await Event.create(name="First", tournament=tournament) + await Address.create(city="Santa Monica", street="Ocean", event=event) + + fetched_address = await event.address + assert fetched_address.city == "Santa Monica" + + +@pytest.mark.asyncio +async def test_m2m_remove_two(db): + tournament = await Tournament.create(name="tournament") + event = await Event.create(name="First", tournament=tournament) + team = await Team.create(name="1") + team_second = await Team.create(name="2") + await event.participants.add(team, team_second) + await event.participants.remove(team, team_second) + fetched_event = await Event.first().prefetch_related("participants") + assert len(fetched_event.participants) == 0 + + +@pytest.mark.asyncio +async def test_self_ref(db): + root = await Employee.create(name="Root") + loose = await Employee.create(name="Loose") + _1 = await Employee.create(name="1. First H1", manager=root) + _2 = await Employee.create(name="2. Second H1", manager=root) + _1_1 = await Employee.create(name="1.1. First H2", manager=_1) + _1_1_1 = await Employee.create(name="1.1.1. First H3", manager=_1_1) + _2_1 = await Employee.create(name="2.1. Second H2", manager=_2) + _2_2 = await Employee.create(name="2.2. Third H2", manager=_2) + + await _1.talks_to.add(_2, _1_1_1, loose) + await _2_1.gets_talked_to.add(_2_2, _1_1, loose) + + LOOSE_TEXT = "Loose (to: 2.1. Second H2) (from: 1. 
First H1)" + ROOT_TEXT = """Root (to: ) (from: ) 1. First H1 (to: 1.1.1. First H3, 2. Second H1, Loose) (from: ) 1.1. First H2 (to: 2.1. Second H2) (from: ) 1.1.1. First H3 (to: ) (from: 1. First H1) @@ -197,318 +228,378 @@ async def test_self_ref(self): 2.1. Second H2 (to: ) (from: 1.1. First H2, 2.2. Third H2, Loose) 2.2. Third H2 (to: 2.1. Second H2) (from: )""" - # Evaluated off creation objects - self.assertEqual(await loose.full_hierarchy__async_for(), LOOSE_TEXT) - self.assertEqual(await loose.full_hierarchy__fetch_related(), LOOSE_TEXT) - self.assertEqual(await root.full_hierarchy__async_for(), ROOT_TEXT) - self.assertEqual(await root.full_hierarchy__fetch_related(), ROOT_TEXT) - - # Evaluated off new objects → Result is identical - root2 = await Employee.get(name="Root") - loose2 = await Employee.get(name="Loose") - self.assertEqual(await loose2.full_hierarchy__async_for(), LOOSE_TEXT) - self.assertEqual(await loose2.full_hierarchy__fetch_related(), LOOSE_TEXT) - self.assertEqual(await root2.full_hierarchy__async_for(), ROOT_TEXT) - self.assertEqual(await root2.full_hierarchy__fetch_related(), ROOT_TEXT) - - async def test_self_ref_filter_by_child(self): - root = await Employee.create(name="Root") - await Employee.create(name="1. First H1", manager=root) - await Employee.create(name="2. Second H1", manager=root) - - root2 = await Employee.get(team_members__name="1. First H1") - self.assertEqual(root.id, root2.id) - - async def test_self_ref_filter_both(self): - root = await Employee.create(name="Root") - await Employee.create(name="1. First H1", manager=root) - await Employee.create(name="2. Second H1", manager=root) - - root2 = await Employee.get(name="Root", team_members__name="1. First H1") - self.assertEqual(root.id, root2.id) - - async def test_self_ref_annotate(self): - root = await Employee.create(name="Root") - await Employee.create(name="Loose") - await Employee.create(name="1. First H1", manager=root) - await Employee.create(name="2. 
Second H1", manager=root) - - root_ann = await Employee.get(name="Root").annotate(num_team_members=Count("team_members")) - self.assertEqual(root_ann.num_team_members, 2) - root_ann = await Employee.get(name="Loose").annotate(num_team_members=Count("team_members")) - self.assertEqual(root_ann.num_team_members, 0) - - async def test_prefetch_related_fk(self): - tournament = await Tournament.create(name="New Tournament") - await Event.create(name="Test", tournament_id=tournament.id) - - event2 = await Event.filter(name="Test").prefetch_related("tournament") - self.assertEqual(event2[0].tournament, tournament) - - async def test_prefetch_related_rfk(self): - tournament = await Tournament.create(name="New Tournament") - event = await Event.create(name="Test", tournament_id=tournament.id) - - tournament2 = await Tournament.filter(name="New Tournament").prefetch_related("events") - self.assertEqual(list(tournament2[0].events), [event]) - - async def test_prefetch_related_missing_field(self): - tournament = await Tournament.create(name="New Tournament") - await Event.create(name="Test", tournament_id=tournament.id) - - with self.assertRaisesRegex(FieldError, "Relation tourn1ment for models.Event not found"): - await Event.filter(name="Test").prefetch_related("tourn1ment") - - async def test_prefetch_related_nonrel_field(self): - tournament = await Tournament.create(name="New Tournament") - await Event.create(name="Test", tournament_id=tournament.id) - - with self.assertRaisesRegex(FieldError, "Field modified on models.Event is not a relation"): - await Event.filter(name="Test").prefetch_related("modified") - - async def test_prefetch_related_id(self): - tournament = await Tournament.create(name="New Tournament") - await Event.create(name="Test", tournament_id=tournament.id) - - with self.assertRaisesRegex(FieldError, "Field event_id on models.Event is not a relation"): - await Event.filter(name="Test").prefetch_related("event_id") - - async def test_nullable_fk_raw(self): 
- tournament = await Tournament.create(name="New Tournament") - reporter = await Reporter.create(name="Reporter") - event1 = await Event.create(name="Without reporter", tournament=tournament) - event2 = await Event.create(name="With reporter", tournament=tournament, reporter=reporter) - - self.assertFalse(event1.reporter_id) - self.assertTrue(event2.reporter_id) - - async def test_nullable_fk_obj(self): - tournament = await Tournament.create(name="New Tournament") - reporter = await Reporter.create(name="Reporter") - event1 = await Event.create(name="Without reporter", tournament=tournament) - event2 = await Event.create(name="With reporter", tournament=tournament, reporter=reporter) - - self.assertFalse(event1.reporter) - self.assertTrue(event2.reporter) - - async def test_db_constraint(self): - author = await Author.create(name="Some One") - book = await BookNoConstraint.create(name="First!", author=author, rating=4) - book = await BookNoConstraint.all().select_related("author").get(pk=book.pk) - self.assertEqual(author.pk, book.author.pk) - - async def test_select_related_with_annotation(self): - tournament = await Tournament.create(name="New Tournament") - reporter = await Reporter.create(name="Reporter") - event = await Event.create(name="With reporter", tournament=tournament, reporter=reporter) - event = ( - await Event.filter(pk=event.pk) - .select_related("reporter") - .annotate(tournament_name=Trim("tournament__name")) - .first() - ) - self.assertEqual(event.reporter, reporter) - self.assertTrue(hasattr(event, "tournament_name")) - self.assertEqual(event.tournament_name, tournament.name) - - async def test_select_related_sets_null_for_null_fk(self): - """Test that select related yields null for fields with nulled fk cols.""" - related_dude = await UUIDFkRelatedNullModel.create(name="Some model") - await related_dude.fetch_related("parent") # that is strange :) - related_dude_fresh = ( - await 
UUIDFkRelatedNullModel.all().select_related("parent").get(id=related_dude.id) - ) - self.assertIsNone(related_dude_fresh.parent) - self.assertEqual(related_dude_fresh.parent, related_dude.parent) - - async def test_select_related_sets_valid_nulls(self) -> None: - """When we select related objects, the data we get from db should be set to corresponding attribute.""" - left_2nd_lvl = await DoubleFK.create(name="second leaf") - left_1st_lvl = await DoubleFK.create(name="1st", left=left_2nd_lvl) - root = await DoubleFK.create(name="root", left=left_1st_lvl) - - retrieved_root = ( - await DoubleFK.all().select_related("left__left__left", "right").get(id=root.pk) - ) - self.assertIsNone(retrieved_root.right) - assert retrieved_root.left is not None - self.assertEqual(retrieved_root.left, left_1st_lvl) - self.assertEqual(retrieved_root.left.left, left_2nd_lvl) - - async def test_no_ambiguous_fk_relations_set(self): - """Basic select_related test cases provided by @https://github.com/Terrance. - - The idea was that on the moment of writing this feature, there were no way to correctly set attributes for - select_related fields attributes. 
- src: https://github.com/tortoise/tortoise-orm/pull/826#issuecomment-883341557 - """ - - extra = await Extra.create() - single = await Single.create(extra=extra) - await Pair.create(right=single) - pair = ( - await Pair.filter(id=1) - .select_related("left", "left__extra", "right", "right__extra") - .get() - ) - self.assertIsNone(pair.left) - self.assertEqual(pair.right.extra, extra) - single = await Single.create() - await Pair.create(right=single) - pair = ( - await Pair.filter(id=2) - .select_related("left", "left__extra", "right", "right__extra") - .get() - ) - self.assertIsNone(pair.right.extra) # should be None - - @test.requireCapability(dialect=NotIn("mssql", "mysql")) - async def test_0_value_fk(self): - """ForegnKeyField should exits even if the the source_field looks like false, but not None - src: https://github.com/tortoise/tortoise-orm/issues/1274 - """ - extra = await Extra.create(id=0) - single = await Single.create(extra=extra) - - single_reload = await Single.get(id=single.id) - assert (await single_reload.extra).id == 0 - - tournament_0 = await Tournament.create(name="tournament zero", id=0) - await Event.create(name="event-zero", tournament=tournament_0) - - e = await Event.get(name="event-zero") - id_before_fetch = e.tournament_id - await e.fetch_related("tournament") - id_after_fetch = e.tournament_id - self.assertEqual(id_before_fetch, id_after_fetch) - - event_0 = await Event.get(name="event-zero").prefetch_related("tournament") - self.assertEqual(event_0.tournament, tournament_0) - - -class TestDoubleFK(test.TestCase): - select_match = r'SELECT [`"]doublefk[`"].[`"]name[`"] [`"]name[`"]' - select1_match = r'[`"]doublefk__left[`"].[`"]name[`"] [`"]left__name[`"]' - select2_match = r'[`"]doublefk__right[`"].[`"]name[`"] [`"]right__name[`"]' - join1_match = ( - r'LEFT OUTER JOIN [`"]doublefk[`"] [`"]doublefk__left[`"] ON ' - r'[`"]doublefk__left[`"].[`"]id[`"]=[`"]doublefk[`"].[`"]left_id[`"]' + # Evaluated off creation objects + assert await 
loose.full_hierarchy__async_for() == LOOSE_TEXT + assert await loose.full_hierarchy__fetch_related() == LOOSE_TEXT + assert await root.full_hierarchy__async_for() == ROOT_TEXT + assert await root.full_hierarchy__fetch_related() == ROOT_TEXT + + # Evaluated off new objects -> Result is identical + root2 = await Employee.get(name="Root") + loose2 = await Employee.get(name="Loose") + assert await loose2.full_hierarchy__async_for() == LOOSE_TEXT + assert await loose2.full_hierarchy__fetch_related() == LOOSE_TEXT + assert await root2.full_hierarchy__async_for() == ROOT_TEXT + assert await root2.full_hierarchy__fetch_related() == ROOT_TEXT + + +@pytest.mark.asyncio +async def test_self_ref_filter_by_child(db): + root = await Employee.create(name="Root") + await Employee.create(name="1. First H1", manager=root) + await Employee.create(name="2. Second H1", manager=root) + + root2 = await Employee.get(team_members__name="1. First H1") + assert root.id == root2.id + + +@pytest.mark.asyncio +async def test_self_ref_filter_both(db): + root = await Employee.create(name="Root") + await Employee.create(name="1. First H1", manager=root) + await Employee.create(name="2. Second H1", manager=root) + + root2 = await Employee.get(name="Root", team_members__name="1. First H1") + assert root.id == root2.id + + +@pytest.mark.asyncio +async def test_self_ref_annotate(db): + root = await Employee.create(name="Root") + await Employee.create(name="Loose") + await Employee.create(name="1. First H1", manager=root) + await Employee.create(name="2. 
Second H1", manager=root) + + root_ann = await Employee.get(name="Root").annotate(num_team_members=Count("team_members")) + assert root_ann.num_team_members == 2 + root_ann = await Employee.get(name="Loose").annotate(num_team_members=Count("team_members")) + assert root_ann.num_team_members == 0 + + +@pytest.mark.asyncio +async def test_prefetch_related_fk(db): + tournament = await Tournament.create(name="New Tournament") + await Event.create(name="Test", tournament_id=tournament.id) + + event2 = await Event.filter(name="Test").prefetch_related("tournament") + assert event2[0].tournament == tournament + + +@pytest.mark.asyncio +async def test_prefetch_related_rfk(db): + tournament = await Tournament.create(name="New Tournament") + event = await Event.create(name="Test", tournament_id=tournament.id) + + tournament2 = await Tournament.filter(name="New Tournament").prefetch_related("events") + assert list(tournament2[0].events) == [event] + + +@pytest.mark.asyncio +async def test_prefetch_related_missing_field(db): + tournament = await Tournament.create(name="New Tournament") + await Event.create(name="Test", tournament_id=tournament.id) + + with pytest.raises(FieldError, match="Relation tourn1ment for models.Event not found"): + await Event.filter(name="Test").prefetch_related("tourn1ment") + + +@pytest.mark.asyncio +async def test_prefetch_related_nonrel_field(db): + tournament = await Tournament.create(name="New Tournament") + await Event.create(name="Test", tournament_id=tournament.id) + + with pytest.raises(FieldError, match="Field modified on models.Event is not a relation"): + await Event.filter(name="Test").prefetch_related("modified") + + +@pytest.mark.asyncio +async def test_prefetch_related_id(db): + tournament = await Tournament.create(name="New Tournament") + await Event.create(name="Test", tournament_id=tournament.id) + + with pytest.raises(FieldError, match="Field event_id on models.Event is not a relation"): + await 
Event.filter(name="Test").prefetch_related("event_id") + + +@pytest.mark.asyncio +async def test_nullable_fk_raw(db): + tournament = await Tournament.create(name="New Tournament") + reporter = await Reporter.create(name="Reporter") + event1 = await Event.create(name="Without reporter", tournament=tournament) + event2 = await Event.create(name="With reporter", tournament=tournament, reporter=reporter) + + assert not event1.reporter_id + assert event2.reporter_id + + +@pytest.mark.asyncio +async def test_nullable_fk_obj(db): + tournament = await Tournament.create(name="New Tournament") + reporter = await Reporter.create(name="Reporter") + event1 = await Event.create(name="Without reporter", tournament=tournament) + event2 = await Event.create(name="With reporter", tournament=tournament, reporter=reporter) + + assert not event1.reporter + assert event2.reporter + + +@pytest.mark.asyncio +async def test_db_constraint(db): + author = await Author.create(name="Some One") + book = await BookNoConstraint.create(name="First!", author=author, rating=4) + book = await BookNoConstraint.all().select_related("author").get(pk=book.pk) + assert author.pk == book.author.pk + + +@pytest.mark.asyncio +async def test_select_related_with_annotation(db): + tournament = await Tournament.create(name="New Tournament") + reporter = await Reporter.create(name="Reporter") + event = await Event.create(name="With reporter", tournament=tournament, reporter=reporter) + event = ( + await Event.filter(pk=event.pk) + .select_related("reporter") + .annotate(tournament_name=Trim("tournament__name")) + .first() + ) + assert event.reporter == reporter + assert hasattr(event, "tournament_name") + assert event.tournament_name == tournament.name + + +@pytest.mark.asyncio +async def test_select_related_sets_null_for_null_fk(db): + """Test that select related yields null for fields with nulled fk cols.""" + related_dude = await UUIDFkRelatedNullModel.create(name="Some model") + await 
related_dude.fetch_related("parent") # that is strange :) + related_dude_fresh = ( + await UUIDFkRelatedNullModel.all().select_related("parent").get(id=related_dude.id) + ) + assert related_dude_fresh.parent is None + assert related_dude_fresh.parent == related_dude.parent + + +@pytest.mark.asyncio +async def test_select_related_sets_valid_nulls(db) -> None: + """When we select related objects, the data we get from db should be set to corresponding attribute.""" + left_2nd_lvl = await DoubleFK.create(name="second leaf") + left_1st_lvl = await DoubleFK.create(name="1st", left=left_2nd_lvl) + root = await DoubleFK.create(name="root", left=left_1st_lvl) + + retrieved_root = ( + await DoubleFK.all().select_related("left__left__left", "right").get(id=root.pk) ) - join2_match = ( - r'LEFT OUTER JOIN [`"]doublefk[`"] [`"]doublefk__right[`"] ON ' - r'[`"]doublefk__right[`"].[`"]id[`"]=[`"]doublefk[`"].[`"]right_id[`"]' + assert retrieved_root.right is None + assert retrieved_root.left is not None + assert retrieved_root.left == left_1st_lvl + assert retrieved_root.left.left == left_2nd_lvl + + +@pytest.mark.asyncio +async def test_no_ambiguous_fk_relations_set(db): + """Basic select_related test cases provided by @https://github.com/Terrance. + + The idea was that on the moment of writing this feature, there were no way to correctly set attributes for + select_related fields attributes. 
+ src: https://github.com/tortoise/tortoise-orm/pull/826#issuecomment-883341557 + """ + + extra = await Extra.create() + single = await Single.create(extra=extra) + await Pair.create(right=single) + pair = ( + await Pair.filter(id=1).select_related("left", "left__extra", "right", "right__extra").get() ) + assert pair.left is None + assert pair.right.extra == extra + single = await Single.create() + await Pair.create(right=single) + pair = ( + await Pair.filter(id=2).select_related("left", "left__extra", "right", "right__extra").get() + ) + assert pair.right.extra is None # should be None + + +@requireCapability(dialect=NotIn("mssql", "mysql")) +@pytest.mark.asyncio +async def test_0_value_fk(db): + """ForeignKeyField should exist even if the source_field looks like false, but not None + src: https://github.com/tortoise/tortoise-orm/issues/1274 + """ + extra = await Extra.create(id=0) + single = await Single.create(extra=extra) + + single_reload = await Single.get(id=single.id) + assert (await single_reload.extra).id == 0 + + tournament_0 = await Tournament.create(name="tournament zero", id=0) + await Event.create(name="event-zero", tournament=tournament_0) + + e = await Event.get(name="event-zero") + id_before_fetch = e.tournament_id + await e.fetch_related("tournament") + id_after_fetch = e.tournament_id + assert id_before_fetch == id_after_fetch + + event_0 = await Event.get(name="event-zero").prefetch_related("tournament") + assert event_0.tournament == tournament_0 + + +# ============================================================================= +# TestDoubleFK - uses db fixture with setup data +# ============================================================================= + + +# Regex patterns for SQL query validation +_select_match = r'SELECT [`"]doublefk[`"].[`"]name[`"] [`"]name[`"]' +_select1_match = r'[`"]doublefk__left[`"].[`"]name[`"] [`"]left__name[`"]' +_select2_match = r'[`"]doublefk__right[`"].[`"]name[`"] [`"]right__name[`"]' +_join1_match = 
( + r'LEFT OUTER JOIN [`"]doublefk[`"] [`"]doublefk__left[`"] ON ' + r'[`"]doublefk__left[`"].[`"]id[`"]=[`"]doublefk[`"].[`"]left_id[`"]' +) +_join2_match = ( + r'LEFT OUTER JOIN [`"]doublefk[`"] [`"]doublefk__right[`"] ON ' + r'[`"]doublefk__right[`"].[`"]id[`"]=[`"]doublefk[`"].[`"]right_id[`"]' +) + + +@pytest_asyncio.fixture +async def doublefk_data(db): + """Build DoubleFK test data.""" + one = await DoubleFK.create(name="one") + two = await DoubleFK.create(name="two") + middle = await DoubleFK.create(name="middle", left=one, right=two) + return middle + + +@pytest.mark.asyncio +async def test_doublefk_filter(db, doublefk_data): + middle = doublefk_data + qset = DoubleFK.filter(left__name="one") + result = await qset + query = qset.query.get_sql() + + assert re.search(_join1_match, query) + assert result == [middle] + + +@pytest.mark.asyncio +async def test_doublefk_filter_values(db, doublefk_data): + qset = DoubleFK.filter(left__name="one").values("name") + result = await qset + query = qset.query.get_sql() + + assert re.search(_select_match, query) + assert re.search(_join1_match, query) + assert result == [{"name": "middle"}] + + +@pytest.mark.asyncio +async def test_doublefk_filter_values_rel(db, doublefk_data): + qset = DoubleFK.filter(left__name="one").values("name", "left__name") + result = await qset + query = qset.query.get_sql() + + assert re.search(_select_match, query) + assert re.search(_select1_match, query) + assert re.search(_join1_match, query) + assert result == [{"name": "middle", "left__name": "one"}] + + +@pytest.mark.asyncio +async def test_doublefk_filter_both(db, doublefk_data): + middle = doublefk_data + qset = DoubleFK.filter(left__name="one", right__name="two") + result = await qset + query = qset.query.get_sql() + + assert re.search(_join1_match, query) + assert re.search(_join2_match, query) + assert result == [middle] + + +@pytest.mark.asyncio +async def test_doublefk_filter_both_values(db, doublefk_data): + qset = 
DoubleFK.filter(left__name="one", right__name="two").values("name") + result = await qset + query = qset.query.get_sql() + + assert re.search(_select_match, query) + assert re.search(_join1_match, query) + assert re.search(_join2_match, query) + assert result == [{"name": "middle"}] + + +@pytest.mark.asyncio +async def test_doublefk_filter_both_values_rel(db, doublefk_data): + qset = DoubleFK.filter(left__name="one", right__name="two").values( + "name", "left__name", "right__name" + ) + result = await qset + query = qset.query.get_sql() + + assert re.search(_select_match, query) + assert re.search(_select1_match, query) + assert re.search(_select2_match, query) + assert re.search(_join1_match, query) + assert re.search(_join2_match, query) + assert result == [{"name": "middle", "left__name": "one", "right__name": "two"}] + + +@pytest.mark.asyncio +async def test_many2many_field_with_o2o_fk(db): + tournament = await Tournament.create(name="t") + event = await Event.create(name="e", tournament=tournament) + address = await Address.create(city="c", street="s", event=event) + obj = await M2mWithO2oPk.create(name="m") + assert await obj.address.all() == [] + await obj.address.add(address) + assert await obj.address.all() == [address] + + +@pytest.mark.asyncio +async def test_o2o_fk_model_with_m2m_field(db): + author = await Author.create(name="a") + obj = await O2oPkModelWithM2m.create(author=author) + node = await Node.create(name="n") + assert await obj.nodes.all() == [] + await obj.nodes.add(node) + assert await obj.nodes.all() == [node] + + +@pytest.mark.asyncio +async def test_reverse_relation_create_fk(db): + tournament = await Tournament.create(name="Test Tournament") + assert await tournament.events.all() == [] + + event = await tournament.events.create(name="Test Event") + + await tournament.fetch_related("events") + + assert len(tournament.events) == 1 + assert event.name == "Test Event" + assert event.tournament_id == tournament.id + assert 
tournament.events[0].event_id == event.event_id + + +@pytest.mark.asyncio +async def test_reverse_relation_create_fk_errors_for_unsaved_instance(db): + tournament = Tournament(name="Unsaved Tournament") + + # Should raise OperationalError since tournament isn't saved + with pytest.raises(OperationalError) as cm: + await tournament.events.create(name="Test Event") + + assert "hasn't been instanced" in str(cm.value) + - async def asyncSetUp(self) -> None: - await super().asyncSetUp() - one = await DoubleFK.create(name="one") - two = await DoubleFK.create(name="two") - self.middle = await DoubleFK.create(name="middle", left=one, right=two) - - async def test_doublefk_filter(self): - qset = DoubleFK.filter(left__name="one") - result = await qset - query = qset.query.get_sql() - - self.assertRegex(query, self.join1_match) - self.assertEqual(result, [self.middle]) - - async def test_doublefk_filter_values(self): - qset = DoubleFK.filter(left__name="one").values("name") - result = await qset - query = qset.query.get_sql() - - self.assertRegex(query, self.select_match) - self.assertRegex(query, self.join1_match) - self.assertEqual(result, [{"name": "middle"}]) - - async def test_doublefk_filter_values_rel(self): - qset = DoubleFK.filter(left__name="one").values("name", "left__name") - result = await qset - query = qset.query.get_sql() - - self.assertRegex(query, self.select_match) - self.assertRegex(query, self.select1_match) - self.assertRegex(query, self.join1_match) - self.assertEqual(result, [{"name": "middle", "left__name": "one"}]) - - async def test_doublefk_filter_both(self): - qset = DoubleFK.filter(left__name="one", right__name="two") - result = await qset - query = qset.query.get_sql() - - self.assertRegex(query, self.join1_match) - self.assertRegex(query, self.join2_match) - self.assertEqual(result, [self.middle]) - - async def test_doublefk_filter_both_values(self): - qset = DoubleFK.filter(left__name="one", right__name="two").values("name") - result = await 
qset - query = qset.query.get_sql() - - self.assertRegex(query, self.select_match) - self.assertRegex(query, self.join1_match) - self.assertRegex(query, self.join2_match) - self.assertEqual(result, [{"name": "middle"}]) - - async def test_doublefk_filter_both_values_rel(self): - qset = DoubleFK.filter(left__name="one", right__name="two").values( - "name", "left__name", "right__name" - ) - result = await qset - query = qset.query.get_sql() - - self.assertRegex(query, self.select_match) - self.assertRegex(query, self.select1_match) - self.assertRegex(query, self.select2_match) - self.assertRegex(query, self.join1_match) - self.assertRegex(query, self.join2_match) - self.assertEqual(result, [{"name": "middle", "left__name": "one", "right__name": "two"}]) - - async def test_many2many_field_with_o2o_fk(self): - tournament = await Tournament.create(name="t") - event = await Event.create(name="e", tournament=tournament) - address = await Address.create(city="c", street="s", event=event) - obj = await M2mWithO2oPk.create(name="m") - self.assertEqual(await obj.address.all(), []) - await obj.address.add(address) - self.assertEqual(await obj.address.all(), [address]) - - async def test_o2o_fk_model_with_m2m_field(self): - author = await Author.create(name="a") - obj = await O2oPkModelWithM2m.create(author=author) - node = await Node.create(name="n") - self.assertEqual(await obj.nodes.all(), []) - await obj.nodes.add(node) - self.assertEqual(await obj.nodes.all(), [node]) - - async def test_reverse_relation_create_fk(self): - tournament = await Tournament.create(name="Test Tournament") - self.assertEqual(await tournament.events.all(), []) - - event = await tournament.events.create(name="Test Event") - - await tournament.fetch_related("events") - - self.assertEqual(len(tournament.events), 1) - self.assertEqual(event.name, "Test Event") - self.assertEqual(event.tournament_id, tournament.id) - self.assertEqual(tournament.events[0].event_id, event.event_id) - - async def 
test_reverse_relation_create_fk_errors_for_unsaved_instance(self): - tournament = Tournament(name="Unsaved Tournament") - - # Should raise OperationalError since tournament isn't saved - with self.assertRaises(OperationalError) as cm: - await tournament.events.create(name="Test Event") - - self.assertIn("hasn't been instanced", str(cm.exception)) - - @test.requireCapability(dialect="sqlite") - async def test_recursive(self) -> None: - file = "examples/relations_recursive.py" - r = subprocess.run(["python", file], capture_output=True, text=True) # nosec - assert not r.stderr - output = r.stdout - s = "2.1. Second H2 (to: ) (from: 2.2. Third H2, Loose, 1.1. First H2)" - self.assertIn(s, output) +@requireCapability(dialect="sqlite") +@pytest.mark.asyncio +async def test_recursive(db) -> None: + file = "examples/relations_recursive.py" + r = subprocess.run([sys.executable, file], capture_output=True, text=True) # nosec + assert not r.stderr, f"Script had errors: {r.stderr}" + output = r.stdout + s = "2.1. Second H2 (to: ) (from: 2.2. Third H2, Loose, 1.1. 
First H2)" + assert s in output diff --git a/tests/test_relations_with_unique.py b/tests/test_relations_with_unique.py index 2a0f4ca4a..99b980db8 100644 --- a/tests/test_relations_with_unique.py +++ b/tests/test_relations_with_unique.py @@ -1,47 +1,46 @@ +import pytest + from tests.testmodels import Principal, School, Student -from tortoise.contrib import test from tortoise.query_utils import Prefetch -class TestRelationsWithUnique(test.TestCase): - async def test_relation_with_unique(self): - school1 = await School.create(id=1024, name="School1") - student1 = await Student.create(name="Sang-Heon Jeon1", school_id=school1.id) - - student_schools = await Student.filter(name="Sang-Heon Jeon1").values( - "name", "school__name" - ) - self.assertEqual(student_schools[0], {"name": "Sang-Heon Jeon1", "school__name": "School1"}) - student_schools = await Student.all().values(school="school__name") - self.assertEqual(student_schools[0]["school"], school1.name) - student_schools = await Student.all().values_list("school__name") - self.assertEqual(student_schools[0][0], school1.name) - - await Student.create(name="Sang-Heon Jeon2", school=school1) - school_with_filtered = ( - await School.all() - .prefetch_related(Prefetch("students", queryset=Student.filter(name="Sang-Heon Jeon1"))) - .first() - ) - school_without_filtered = await School.first().prefetch_related("students") - self.assertEqual(len(school_with_filtered.students), 1) - self.assertEqual(len(school_without_filtered.students), 2) - - student_direct_prefetch = await Student.first().prefetch_related("school") - self.assertEqual(student_direct_prefetch.school.id, school1.id) - - school2 = await School.create(id=2048, name="School2") - await Student.all().update(school=school2) - student = await Student.first() - self.assertEqual(student.school_id, school2.id) - - await Student.filter(id=student1.id).update(school=school1) - schools = await School.all().order_by("students__name") - self.assertEqual([school.name for 
school in schools], ["School1", "School2"]) - schools = await School.all().order_by("-students__name") - self.assertEqual([school.name for school in schools], ["School2", "School1"]) - - fetched_principal = await Principal.create(name="Sang-Heon Jeon3", school=school1) - self.assertEqual(fetched_principal.name, "Sang-Heon Jeon3") - fetched_school = await School.filter(name="School1").prefetch_related("principal").first() - self.assertEqual(fetched_school.name, "School1") +@pytest.mark.asyncio +async def test_relation_with_unique(db): + school1 = await School.create(id=1024, name="School1") + student1 = await Student.create(name="Sang-Heon Jeon1", school_id=school1.id) + + student_schools = await Student.filter(name="Sang-Heon Jeon1").values("name", "school__name") + assert student_schools[0] == {"name": "Sang-Heon Jeon1", "school__name": "School1"} + student_schools = await Student.all().values(school="school__name") + assert student_schools[0]["school"] == school1.name + student_schools = await Student.all().values_list("school__name") + assert student_schools[0][0] == school1.name + + await Student.create(name="Sang-Heon Jeon2", school=school1) + school_with_filtered = ( + await School.all() + .prefetch_related(Prefetch("students", queryset=Student.filter(name="Sang-Heon Jeon1"))) + .first() + ) + school_without_filtered = await School.first().prefetch_related("students") + assert len(school_with_filtered.students) == 1 + assert len(school_without_filtered.students) == 2 + + student_direct_prefetch = await Student.first().prefetch_related("school") + assert student_direct_prefetch.school.id == school1.id + + school2 = await School.create(id=2048, name="School2") + await Student.all().update(school=school2) + student = await Student.first() + assert student.school_id == school2.id + + await Student.filter(id=student1.id).update(school=school1) + schools = await School.all().order_by("students__name") + assert [school.name for school in schools] == ["School1", 
"School2"] + schools = await School.all().order_by("-students__name") + assert [school.name for school in schools] == ["School2", "School1"] + + fetched_principal = await Principal.create(name="Sang-Heon Jeon3", school=school1) + assert fetched_principal.name == "Sang-Heon Jeon3" + fetched_school = await School.filter(name="School1").prefetch_related("principal").first() + assert fetched_school.name == "School1" diff --git a/tests/test_signals.py b/tests/test_signals.py index 75f209f11..732ff1b36 100644 --- a/tests/test_signals.py +++ b/tests/test_signals.py @@ -1,8 +1,10 @@ from __future__ import annotations +import pytest +import pytest_asyncio + from tests.testmodels import Signals from tortoise import BaseDBAsyncClient -from tortoise.contrib import test from tortoise.signals import post_delete, post_save, pre_delete, pre_save @@ -40,43 +42,60 @@ async def signal_post_delete( await Signals.filter(name="test4").update(name="test_post-delete") -class TestSignals(test.TestCase): - async def asyncSetUp(self): - await super().asyncSetUp() - self.signal_save = await Signals.create(name="signal_save") - self.signal_delete = await Signals.create(name="signal_delete") +@pytest_asyncio.fixture +async def signals_data(db): + """Set up test data for signal tests.""" + signal_save = await Signals.create(name="signal_save") + signal_delete = await Signals.create(name="signal_delete") + + signal1 = await Signals.create(name="test1") + signal2 = await Signals.create(name="test2") + signal3 = await Signals.create(name="test3") + signal4 = await Signals.create(name="test4") + signal5 = await Signals.create(name="test5") + signal6 = await Signals.create(name="test6") + + return { + "signal_save": signal_save, + "signal_delete": signal_delete, + "signal1": signal1, + "signal2": signal2, + "signal3": signal3, + "signal4": signal4, + "signal5": signal5, + "signal6": signal6, + } + + +@pytest.mark.asyncio +async def test_create(signals_data): + await Signals.create(name="test-create") 
+ signal5 = await Signals.get(pk=signals_data["signal5"].pk) + signal6 = await Signals.get(pk=signals_data["signal6"].pk) + assert signal5.name == "test_pre-save" + assert signal6.name == "test_post-save" - self.signal1 = await Signals.create(name="test1") - self.signal2 = await Signals.create(name="test2") - self.signal3 = await Signals.create(name="test3") - self.signal4 = await Signals.create(name="test4") - self.signal5 = await Signals.create(name="test5") - self.signal6 = await Signals.create(name="test6") - async def test_create(self): - await Signals.create(name="test-create") - signal5 = await Signals.get(pk=self.signal5.pk) - signal6 = await Signals.get(pk=self.signal6.pk) - self.assertEqual(signal5.name, "test_pre-save") - self.assertEqual(signal6.name, "test_post-save") +@pytest.mark.asyncio +async def test_save(signals_data): + signal_save = await Signals.get(pk=signals_data["signal_save"].pk) + signal_save.name = "test-save" + await signal_save.save() - async def test_save(self): - signal_save = await Signals.get(pk=self.signal_save.pk) - signal_save.name = "test-save" - await signal_save.save() + signal1 = await Signals.get(pk=signals_data["signal1"].pk) + signal2 = await Signals.get(pk=signals_data["signal2"].pk) - signal1 = await Signals.get(pk=self.signal1.pk) - signal2 = await Signals.get(pk=self.signal2.pk) + assert signal1.name == "test_pre-save" + assert signal2.name == "test_post-save" - self.assertEqual(signal1.name, "test_pre-save") - self.assertEqual(signal2.name, "test_post-save") - async def test_delete(self): - signal_delete = await Signals.get(pk=self.signal_delete.pk) - await signal_delete.delete() +@pytest.mark.asyncio +async def test_delete(signals_data): + signal_delete = await Signals.get(pk=signals_data["signal_delete"].pk) + await signal_delete.delete() - signal3 = await Signals.get(pk=self.signal3.pk) - signal4 = await Signals.get(pk=self.signal4.pk) + signal3 = await Signals.get(pk=signals_data["signal3"].pk) + signal4 = await 
Signals.get(pk=signals_data["signal4"].pk) - self.assertEqual(signal3.name, "test_pre-delete") - self.assertEqual(signal4.name, "test_post-delete") + assert signal3.name == "test_pre-delete" + assert signal4.name == "test_post-delete" diff --git a/tests/test_source_field.py b/tests/test_source_field.py index d0c19c2e6..1e0f14322 100644 --- a/tests/test_source_field.py +++ b/tests/test_source_field.py @@ -5,6 +5,8 @@ This is to test that behaviour doesn't change when one defined source_field parameters. """ +import pytest + from tests.testmodels import NumberSourceField, SourceFields, StraightFields from tortoise.contrib import test from tortoise.contrib.test.condition import NotEQ @@ -12,64 +14,78 @@ from tortoise.functions import Coalesce, Count, Length, Lower, Trim, Upper -class StraightFieldTests(test.TestCase): - def setUp(self) -> None: - self.model = StraightFields +# Helper function to sort model instances by pk +def sort_by_pk(items): + return sorted(items, key=lambda x: x.pk) + + +class TestStraightFields: + """Tests for StraightFields model.""" + + model = StraightFields - async def test_get_all(self): + @pytest.mark.asyncio + async def test_get_all(self, db): obj1 = await self.model.create(chars="aaa") - self.assertIsNotNone(obj1.eyedee, str(dir(obj1))) + assert obj1.eyedee is not None, str(dir(obj1)) obj2 = await self.model.create(chars="bbb") objs = await self.model.all() - self.assertListSortEqual(objs, [obj1, obj2]) + assert sort_by_pk(objs) == sort_by_pk([obj1, obj2]) - async def test_get_by_pk(self): + @pytest.mark.asyncio + async def test_get_by_pk(self, db): obj = await self.model.create(chars="aaa") obj1 = await self.model.get(eyedee=obj.eyedee) - self.assertEqual(obj, obj1) + assert obj == obj1 - async def test_get_by_chars(self): + @pytest.mark.asyncio + async def test_get_by_chars(self, db): obj = await self.model.create(chars="aaa") obj1 = await self.model.get(chars="aaa") - self.assertEqual(obj, obj1) + assert obj == obj1 - async def 
test_get_fk_forward_fetch_related(self): + @pytest.mark.asyncio + async def test_get_fk_forward_fetch_related(self, db): obj1 = await self.model.create(chars="aaa") obj2 = await self.model.create(chars="bbb", fk=obj1) obj2a = await self.model.get(eyedee=obj2.eyedee) await obj2a.fetch_related("fk") - self.assertEqual(obj2, obj2a) - self.assertEqual(obj1, obj2a.fk) + assert obj2 == obj2a + assert obj1 == obj2a.fk - async def test_get_fk_forward_prefetch_related(self): + @pytest.mark.asyncio + async def test_get_fk_forward_prefetch_related(self, db): obj1 = await self.model.create(chars="aaa") obj2 = await self.model.create(chars="bbb", fk=obj1) obj2a = await self.model.get(eyedee=obj2.eyedee).prefetch_related("fk") - self.assertEqual(obj2, obj2a) - self.assertEqual(obj1, obj2a.fk) + assert obj2 == obj2a + assert obj1 == obj2a.fk - async def test_get_fk_reverse_await(self): + @pytest.mark.asyncio + async def test_get_fk_reverse_await(self, db): obj1 = await self.model.create(chars="aaa") obj2 = await self.model.create(chars="bbb", fk=obj1) obj3 = await self.model.create(chars="ccc", fk=obj1) obj1a = await self.model.get(eyedee=obj1.eyedee) - self.assertListSortEqual(await obj1a.fkrev, [obj2, obj3]) + assert sort_by_pk(await obj1a.fkrev) == sort_by_pk([obj2, obj3]) - async def test_get_fk_reverse_filter(self): + @pytest.mark.asyncio + async def test_get_fk_reverse_filter(self, db): obj1 = await self.model.create(chars="aaa") obj2 = await self.model.create(chars="bbb", fk=obj1) obj3 = await self.model.create(chars="ccc", fk=obj1) objs = await self.model.filter(fk=obj1) - self.assertListSortEqual(objs, [obj2, obj3]) + assert sort_by_pk(objs) == sort_by_pk([obj2, obj3]) - async def test_get_fk_reverse_async_for(self): + @pytest.mark.asyncio + async def test_get_fk_reverse_async_for(self, db): obj1 = await self.model.create(chars="aaa") obj2 = await self.model.create(chars="bbb", fk=obj1) obj3 = await self.model.create(chars="ccc", fk=obj1) @@ -78,112 +94,125 @@ async def 
test_get_fk_reverse_async_for(self): objs = [] async for obj in obj1a.fkrev: objs.append(obj) - self.assertListSortEqual(objs, [obj2, obj3]) + assert sort_by_pk(objs) == sort_by_pk([obj2, obj3]) - async def test_get_fk_reverse_fetch_related(self): + @pytest.mark.asyncio + async def test_get_fk_reverse_fetch_related(self, db): obj1 = await self.model.create(chars="aaa") obj2 = await self.model.create(chars="bbb", fk=obj1) obj3 = await self.model.create(chars="ccc", fk=obj1) obj1a = await self.model.get(eyedee=obj1.eyedee) await obj1a.fetch_related("fkrev") - self.assertListSortEqual(list(obj1a.fkrev), [obj2, obj3]) + assert sort_by_pk(list(obj1a.fkrev)) == sort_by_pk([obj2, obj3]) - async def test_get_fk_reverse_prefetch_related(self): + @pytest.mark.asyncio + async def test_get_fk_reverse_prefetch_related(self, db): obj1 = await self.model.create(chars="aaa") obj2 = await self.model.create(chars="bbb", fk=obj1) obj3 = await self.model.create(chars="ccc", fk=obj1) obj1a = await self.model.get(eyedee=obj1.eyedee).prefetch_related("fkrev") - self.assertListSortEqual(list(obj1a.fkrev), [obj2, obj3]) + assert sort_by_pk(list(obj1a.fkrev)) == sort_by_pk([obj2, obj3]) - async def test_get_m2m_forward_await(self): + @pytest.mark.asyncio + async def test_get_m2m_forward_await(self, db): obj1 = await self.model.create(chars="aaa") obj2 = await self.model.create(chars="bbb") await obj1.rel_to.add(obj2) obj2a = await self.model.get(eyedee=obj2.eyedee) - self.assertEqual(await obj2a.rel_from, [obj1]) + assert await obj2a.rel_from == [obj1] obj1a = await self.model.get(eyedee=obj1.eyedee) - self.assertEqual(await obj1a.rel_to, [obj2]) + assert await obj1a.rel_to == [obj2] - async def test_get_m2m_reverse_await(self): + @pytest.mark.asyncio + async def test_get_m2m_reverse_await(self, db): obj1 = await self.model.create(chars="aaa") obj2 = await self.model.create(chars="bbb") await obj2.rel_from.add(obj1) obj2a = await self.model.get(pk=obj2.eyedee) - self.assertEqual(await 
obj2a.rel_from, [obj1]) + assert await obj2a.rel_from == [obj1] obj1a = await self.model.get(eyedee=obj1.pk) - self.assertEqual(await obj1a.rel_to, [obj2]) + assert await obj1a.rel_to == [obj2] - async def test_get_m2m_filter(self): + @pytest.mark.asyncio + async def test_get_m2m_filter(self, db): obj1 = await self.model.create(chars="aaa") obj2 = await self.model.create(chars="bbb") await obj1.rel_to.add(obj2) rel_froms = await self.model.filter(rel_from=obj1) - self.assertEqual(rel_froms, [obj2]) + assert rel_froms == [obj2] rel_tos = await self.model.filter(rel_to=obj2) - self.assertEqual(rel_tos, [obj1]) + assert rel_tos == [obj1] - async def test_get_m2m_forward_fetch_related(self): + @pytest.mark.asyncio + async def test_get_m2m_forward_fetch_related(self, db): obj1 = await self.model.create(chars="aaa") obj2 = await self.model.create(chars="bbb") await obj1.rel_to.add(obj2) obj2a = await self.model.get(eyedee=obj2.eyedee) await obj2a.fetch_related("rel_from") - self.assertEqual(list(obj2a.rel_from), [obj1]) + assert list(obj2a.rel_from) == [obj1] - async def test_get_m2m_reverse_fetch_related(self): + @pytest.mark.asyncio + async def test_get_m2m_reverse_fetch_related(self, db): obj1 = await self.model.create(chars="aaa") obj2 = await self.model.create(chars="bbb") await obj1.rel_to.add(obj2) obj1a = await self.model.get(eyedee=obj1.eyedee) await obj1a.fetch_related("rel_to") - self.assertEqual(list(obj1a.rel_to), [obj2]) + assert list(obj1a.rel_to) == [obj2] - async def test_get_m2m_forward_prefetch_related(self): + @pytest.mark.asyncio + async def test_get_m2m_forward_prefetch_related(self, db): obj1 = await self.model.create(chars="aaa") obj2 = await self.model.create(chars="bbb") await obj1.rel_to.add(obj2) obj2a = await self.model.get(eyedee=obj2.eyedee).prefetch_related("rel_from") - self.assertEqual(list(obj2a.rel_from), [obj1]) + assert list(obj2a.rel_from) == [obj1] - async def test_get_m2m_reverse_prefetch_related(self): + @pytest.mark.asyncio + 
async def test_get_m2m_reverse_prefetch_related(self, db): obj1 = await self.model.create(chars="aaa") obj2 = await self.model.create(chars="bbb") await obj1.rel_to.add(obj2) obj1a = await self.model.get(eyedee=obj1.eyedee).prefetch_related("rel_to") - self.assertEqual(list(obj1a.rel_to), [obj2]) + assert list(obj1a.rel_to) == [obj2] - async def test_values_reverse_relation(self): + @pytest.mark.asyncio + async def test_values_reverse_relation(self, db): obj1 = await self.model.create(chars="aaa") await self.model.create(chars="bbb", fk=obj1) obj1a = await self.model.filter(chars="aaa").values("fkrev__chars") - self.assertEqual(obj1a[0]["fkrev__chars"], "bbb") + assert obj1a[0]["fkrev__chars"] == "bbb" - async def test_f_expression(self): + @pytest.mark.asyncio + async def test_f_expression(self, db): obj1 = await self.model.create(chars="aaa") await self.model.filter(eyedee=obj1.eyedee).update(chars=F("blip")) obj2 = await self.model.get(eyedee=obj1.eyedee) - self.assertEqual(obj2.chars, "BLIP") + assert obj2.chars == "BLIP" - async def test_function(self): + @pytest.mark.asyncio + async def test_function(self, db): obj1 = await self.model.create(chars=" aaa ") await self.model.filter(eyedee=obj1.eyedee).update(chars=Trim("chars")) obj2 = await self.model.get(eyedee=obj1.eyedee) - self.assertEqual(obj2.chars, "aaa") + assert obj2.chars == "aaa" - async def test_aggregation_with_filter(self): + @pytest.mark.asyncio + async def test_aggregation_with_filter(self, db): obj1 = await self.model.create(chars="aaa") await self.model.create(chars="bbb", fk=obj1) await self.model.create(chars="ccc", fk=obj1) @@ -198,88 +227,98 @@ async def test_aggregation_with_filter(self): .first() ) - self.assertEqual(obj.all, 2) - self.assertEqual(obj.one, 1) - self.assertEqual(obj.no, 0) + assert obj.all == 2 + assert obj.one == 1 + assert obj.no == 0 - async def test_filter_by_aggregation_field_coalesce(self): + @pytest.mark.asyncio + async def 
test_filter_by_aggregation_field_coalesce(self, db): await self.model.create(chars="aaa", nullable="null") await self.model.create(chars="bbb") objs = await self.model.annotate(null=Coalesce("nullable", "null")).filter(null="null") - self.assertEqual(len(objs), 2) - self.assertSetEqual({(o.chars, o.null) for o in objs}, {("aaa", "null"), ("bbb", "null")}) + assert len(objs) == 2 + assert {(o.chars, o.null) for o in objs} == {("aaa", "null"), ("bbb", "null")} - async def test_filter_by_aggregation_field_count(self): + @pytest.mark.asyncio + async def test_filter_by_aggregation_field_count(self, db): await self.model.create(chars="aaa") await self.model.create(chars="bbb") obj = await self.model.annotate(chars_count=Count("chars")).filter( chars_count=1, chars="aaa" ) - self.assertEqual(len(obj), 1) - self.assertEqual(obj[0].chars, "aaa") + assert len(obj) == 1 + assert obj[0].chars == "aaa" @test.requireCapability(dialect=NotEQ("mssql")) - async def test_filter_by_aggregation_field_length(self): + @pytest.mark.asyncio + async def test_filter_by_aggregation_field_length(self, db): await self.model.create(chars="aaa") await self.model.create(chars="bbbbb") obj = await self.model.annotate(chars_length=Length("chars")).filter(chars_length=3) - self.assertEqual(len(obj), 1) - self.assertEqual(obj[0].chars_length, 3) + assert len(obj) == 1 + assert obj[0].chars_length == 3 - async def test_filter_by_aggregation_field_lower(self): + @pytest.mark.asyncio + async def test_filter_by_aggregation_field_lower(self, db): await self.model.create(chars="AaA") obj = await self.model.annotate(chars_lower=Lower("chars")).filter(chars_lower="aaa") - self.assertEqual(len(obj), 1) - self.assertEqual(obj[0].chars_lower, "aaa") + assert len(obj) == 1 + assert obj[0].chars_lower == "aaa" - async def test_filter_by_aggregation_field_trim(self): + @pytest.mark.asyncio + async def test_filter_by_aggregation_field_trim(self, db): await self.model.create(chars=" aaa ") obj = await 
self.model.annotate(chars_trim=Trim("chars")).filter(chars_trim="aaa") - self.assertEqual(len(obj), 1) - self.assertEqual(obj[0].chars_trim, "aaa") + assert len(obj) == 1 + assert obj[0].chars_trim == "aaa" - async def test_filter_by_aggregation_field_upper(self): + @pytest.mark.asyncio + async def test_filter_by_aggregation_field_upper(self, db): await self.model.create(chars="aAa") obj = await self.model.annotate(chars_upper=Upper("chars")).filter(chars_upper="AAA") - self.assertEqual(len(obj), 1) - self.assertEqual(obj[0].chars_upper, "AAA") + assert len(obj) == 1 + assert obj[0].chars_upper == "AAA" - async def test_values_by_fk(self): + @pytest.mark.asyncio + async def test_values_by_fk(self, db): obj1 = await self.model.create(chars="aaa") await self.model.create(chars="bbb", fk=obj1) obj = await self.model.filter(chars="bbb").values("fk__chars") - self.assertEqual(obj, [{"fk__chars": "aaa"}]) + assert obj == [{"fk__chars": "aaa"}] - async def test_filter_with_field_f(self): + @pytest.mark.asyncio + async def test_filter_with_field_f(self, db): obj = await self.model.create(chars="a") ret_obj = await self.model.filter(eyedee=F("eyedee")).first() - self.assertEqual(obj, ret_obj) + assert obj == ret_obj ret_obj = await self.model.filter(eyedee__lt=F("eyedee") + 1).first() - self.assertEqual(obj, ret_obj) + assert obj == ret_obj - async def test_filter_with_field_f_annotation(self): + @pytest.mark.asyncio + async def test_filter_with_field_f_annotation(self, db): obj = await self.model.create(chars="a") ret_obj = ( await self.model.annotate(eyedee_a=F("eyedee")).filter(eyedee=F("eyedee_a")).first() ) - self.assertEqual(obj, ret_obj) + assert obj == ret_obj ret_obj = ( await self.model.annotate(eyedee_a=F("eyedee") + 1) .filter(eyedee__lt=F("eyedee_a")) .first() ) - self.assertEqual(obj, ret_obj) + assert obj == ret_obj - async def test_group_by(self): + @pytest.mark.asyncio + async def test_group_by(self, db): await self.model.create(chars="aaa", blip="a") await 
self.model.create(chars="aaa", blip="b") await self.model.create(chars="bbb") @@ -290,26 +329,28 @@ async def test_group_by(self): .order_by("chars") .values("chars", "chars_count") ) - self.assertEqual( - objs, [{"chars": "aaa", "chars_count": 2}, {"chars": "bbb", "chars_count": 1}] - ) + assert objs == [{"chars": "aaa", "chars_count": 2}, {"chars": "bbb", "chars_count": 1}] + + +class TestSourceFields(TestStraightFields): + """Tests for SourceFields model (same tests as StraightFields).""" + model = SourceFields # type: ignore[assignment] -class SourceFieldTests(StraightFieldTests): - def setUp(self) -> None: - self.model = SourceFields # type: ignore +class TestNumberSourceField: + """Tests for NumberSourceField model.""" -class NumberSourceFieldTests(test.TestCase): - def setUp(self) -> None: - self.model = NumberSourceField + model = NumberSourceField - async def test_f_expression_save(self): + @pytest.mark.asyncio + async def test_f_expression_save(self, db): obj1 = await self.model.create() obj1.number = F("number") + 1 await obj1.save() - async def test_f_expression_save_update_fields(self): + @pytest.mark.asyncio + async def test_f_expression_save_update_fields(self, db): obj1 = await self.model.create() obj1.number = F("number") + 1 await obj1.save(update_fields=["number"]) diff --git a/tests/test_sql.py b/tests/test_sql.py index fca792a45..3db519515 100644 --- a/tests/test_sql.py +++ b/tests/test_sql.py @@ -1,239 +1,270 @@ +import re + +import pytest + from tests.testmodels import CharPkModel, Event, IntFields from tortoise import connections from tortoise.backends.psycopg.client import PsycopgClient -from tortoise.contrib import test from tortoise.expressions import F from tortoise.functions import Coalesce, Concat -class TestSQL(test.TestCase): - async def asyncSetUp(self): - await super().asyncSetUp() - self.db = connections.get("models") - self.dialect = self.db.schema_generator.DIALECT - self.is_psycopg = isinstance(self.db, PsycopgClient) - - def 
test_filter(self): - sql = CharPkModel.all().filter(id="123").sql() - if self.dialect == "mysql": - expected = "SELECT `id` FROM `charpkmodel` WHERE `id`=%s" - elif self.dialect == "postgres": - if self.is_psycopg: - expected = 'SELECT "id" FROM "charpkmodel" WHERE "id"=%s' - else: - expected = 'SELECT "id" FROM "charpkmodel" WHERE "id"=$1' - else: - expected = 'SELECT "id" FROM "charpkmodel" WHERE "id"=?' - - self.assertEqual(sql, expected) - - def test_filter_with_limit_offset(self): - sql = CharPkModel.all().filter(id="123").limit(10).offset(0).sql() - if self.dialect == "mysql": - expected = "SELECT `id` FROM `charpkmodel` WHERE `id`=%s LIMIT %s OFFSET %s" - elif self.dialect == "postgres": - if self.is_psycopg: - expected = 'SELECT "id" FROM "charpkmodel" WHERE "id"=%s LIMIT %s OFFSET %s' - else: - expected = 'SELECT "id" FROM "charpkmodel" WHERE "id"=$1 LIMIT $2 OFFSET $3' - elif self.dialect == "mssql": - expected = 'SELECT "id" FROM "charpkmodel" WHERE "id"=? ORDER BY (SELECT 0) OFFSET ? ROWS FETCH NEXT ? ROWS ONLY' - else: - expected = 'SELECT "id" FROM "charpkmodel" WHERE "id"=? LIMIT ? OFFSET ?' 
+@pytest.fixture +def sql_context(db): + """Fixture providing database connection, dialect and psycopg flag.""" + db_conn = connections.get("models") + dialect = db_conn.schema_generator.DIALECT + is_psycopg = isinstance(db_conn, PsycopgClient) + return db_conn, dialect, is_psycopg - self.assertEqual(sql, expected) - def test_group_by(self): - sql = IntFields.all().group_by("intnum").values("intnum").sql() - if self.dialect == "mysql": - expected = "SELECT `intnum` `intnum` FROM `intfields` GROUP BY `intnum`" - else: - expected = 'SELECT "intnum" "intnum" FROM "intfields" GROUP BY "intnum"' - self.assertEqual(sql, expected) - - def test_annotate(self): - sql = CharPkModel.all().annotate(id_plus_one=Concat(F("id"), "_postfix")).sql() - if self.dialect == "mysql": - expected = "SELECT `id`,CONCAT(`id`,%s) `id_plus_one` FROM `charpkmodel`" - elif self.dialect == "postgres": - if self.is_psycopg: - expected = ( - 'SELECT "id",CONCAT("id"::text,%s::text) "id_plus_one" FROM "charpkmodel"' - ) - else: - expected = ( - 'SELECT "id",CONCAT("id"::text,$1::text) "id_plus_one" FROM "charpkmodel"' - ) +def test_filter(sql_context): + db, dialect, is_psycopg = sql_context + sql = CharPkModel.all().filter(id="123").sql() + if dialect == "mysql": + expected = "SELECT `id` FROM `charpkmodel` WHERE `id`=%s" + elif dialect == "postgres": + if is_psycopg: + expected = 'SELECT "id" FROM "charpkmodel" WHERE "id"=%s' else: - expected = 'SELECT "id",CONCAT("id",?) 
"id_plus_one" FROM "charpkmodel"' - self.assertEqual(sql, expected) - - def test_annotate_concat_fields(self): - sql = CharPkModel.all().annotate(id_double=Concat(F("id"), F("id"))).sql() - if self.dialect == "mysql": - expected = "SELECT `id`,CONCAT(`id`,`id`) `id_double` FROM `charpkmodel`" - elif self.dialect == "postgres": - expected = 'SELECT "id",CONCAT("id"::text,"id"::text) "id_double" FROM "charpkmodel"' + expected = 'SELECT "id" FROM "charpkmodel" WHERE "id"=$1' + else: + expected = 'SELECT "id" FROM "charpkmodel" WHERE "id"=?' + + assert sql == expected + + +def test_filter_with_limit_offset(sql_context): + db, dialect, is_psycopg = sql_context + sql = CharPkModel.all().filter(id="123").limit(10).offset(0).sql() + if dialect == "mysql": + expected = "SELECT `id` FROM `charpkmodel` WHERE `id`=%s LIMIT %s OFFSET %s" + elif dialect == "postgres": + if is_psycopg: + expected = 'SELECT "id" FROM "charpkmodel" WHERE "id"=%s LIMIT %s OFFSET %s' else: - expected = 'SELECT "id",CONCAT("id","id") "id_double" FROM "charpkmodel"' - self.assertEqual(sql, expected) - - def test_annotate_coalesce_field_expression(self): - sql = IntFields.all().annotate(num=Coalesce("intnum", F("intnum_null"))).values("num").sql() - if self.dialect == "mysql": - expected = "SELECT COALESCE(`intnum`,`intnum_null`) `num` FROM `intfields`" - elif self.dialect == "postgres": - expected = 'SELECT COALESCE("intnum","intnum_null") "num" FROM "intfields"' + expected = 'SELECT "id" FROM "charpkmodel" WHERE "id"=$1 LIMIT $2 OFFSET $3' + elif dialect == "mssql": + expected = 'SELECT "id" FROM "charpkmodel" WHERE "id"=? ORDER BY (SELECT 0) OFFSET ? ROWS FETCH NEXT ? ROWS ONLY' + else: + expected = 'SELECT "id" FROM "charpkmodel" WHERE "id"=? LIMIT ? OFFSET ?' 
+ + assert sql == expected + + +def test_group_by(sql_context): + db, dialect, is_psycopg = sql_context + sql = IntFields.all().group_by("intnum").values("intnum").sql() + if dialect == "mysql": + expected = "SELECT `intnum` `intnum` FROM `intfields` GROUP BY `intnum`" + else: + expected = 'SELECT "intnum" "intnum" FROM "intfields" GROUP BY "intnum"' + assert sql == expected + + +def test_annotate(sql_context): + db, dialect, is_psycopg = sql_context + sql = CharPkModel.all().annotate(id_plus_one=Concat(F("id"), "_postfix")).sql() + if dialect == "mysql": + expected = "SELECT `id`,CONCAT(`id`,%s) `id_plus_one` FROM `charpkmodel`" + elif dialect == "postgres": + if is_psycopg: + expected = 'SELECT "id",CONCAT("id"::text,%s::text) "id_plus_one" FROM "charpkmodel"' else: - expected = 'SELECT COALESCE("intnum","intnum_null") "num" FROM "intfields"' - self.assertEqual(sql, expected) - - def test_annotate_function_join_expression(self): - qset = ( - Event.all() - .annotate(full_name=Concat("name", F("tournament__name"))) - .values("full_name") - ) - sql = qset.sql() - join_match = ( - r'LEFT OUTER JOIN [`"]tournament[`"] [`"]event__tournament[`"] ON ' - r'[`"]event__tournament[`"]\.[`"]id[`"]=[`"]event[`"]\.[`"]tournament_id[`"]' - ) - self.assertRegex(sql, join_match) - concat_match = ( - r"CONCAT\(`?event`?\.`?name`?(?:::text)?\s*,\s*`?event__tournament`?\.`?name`?" 
- r"(?:::text)?\)" - r'|CONCAT\("event"\."name"(?:::text)?\s*,\s*"event__tournament"\."name"' - r"(?:::text)?\)" - ) - self.assertRegex(sql, concat_match) - - def test_values(self): - sql = IntFields.filter(intnum=1).values("intnum").sql() - if self.dialect == "mysql": - expected = "SELECT `intnum` `intnum` FROM `intfields` WHERE `intnum`=%s" - elif self.dialect == "postgres": - if self.is_psycopg: - expected = 'SELECT "intnum" "intnum" FROM "intfields" WHERE "intnum"=%s' - else: - expected = 'SELECT "intnum" "intnum" FROM "intfields" WHERE "intnum"=$1' + expected = 'SELECT "id",CONCAT("id"::text,$1::text) "id_plus_one" FROM "charpkmodel"' + else: + expected = 'SELECT "id",CONCAT("id",?) "id_plus_one" FROM "charpkmodel"' + assert sql == expected + + +def test_annotate_concat_fields(sql_context): + db, dialect, is_psycopg = sql_context + sql = CharPkModel.all().annotate(id_double=Concat(F("id"), F("id"))).sql() + if dialect == "mysql": + expected = "SELECT `id`,CONCAT(`id`,`id`) `id_double` FROM `charpkmodel`" + elif dialect == "postgres": + expected = 'SELECT "id",CONCAT("id"::text,"id"::text) "id_double" FROM "charpkmodel"' + else: + expected = 'SELECT "id",CONCAT("id","id") "id_double" FROM "charpkmodel"' + assert sql == expected + + +def test_annotate_coalesce_field_expression(sql_context): + db, dialect, is_psycopg = sql_context + sql = IntFields.all().annotate(num=Coalesce("intnum", F("intnum_null"))).values("num").sql() + if dialect == "mysql": + expected = "SELECT COALESCE(`intnum`,`intnum_null`) `num` FROM `intfields`" + elif dialect == "postgres": + expected = 'SELECT COALESCE("intnum","intnum_null") "num" FROM "intfields"' + else: + expected = 'SELECT COALESCE("intnum","intnum_null") "num" FROM "intfields"' + assert sql == expected + + +def test_annotate_function_join_expression(sql_context): + db, dialect, is_psycopg = sql_context + qset = Event.all().annotate(full_name=Concat("name", F("tournament__name"))).values("full_name") + sql = qset.sql() + 
join_match = ( + r'LEFT OUTER JOIN [`"]tournament[`"] [`"]event__tournament[`"] ON ' + r'[`"]event__tournament[`"]\.[`"]id[`"]=[`"]event[`"]\.[`"]tournament_id[`"]' + ) + assert re.search(join_match, sql) + concat_match = ( + r"CONCAT\(`?event`?\.`?name`?(?:::text)?\s*,\s*`?event__tournament`?\.`?name`?" + r"(?:::text)?\)" + r'|CONCAT\("event"\."name"(?:::text)?\s*,\s*"event__tournament"\."name"' + r"(?:::text)?\)" + ) + assert re.search(concat_match, sql) + + +def test_values(sql_context): + db, dialect, is_psycopg = sql_context + sql = IntFields.filter(intnum=1).values("intnum").sql() + if dialect == "mysql": + expected = "SELECT `intnum` `intnum` FROM `intfields` WHERE `intnum`=%s" + elif dialect == "postgres": + if is_psycopg: + expected = 'SELECT "intnum" "intnum" FROM "intfields" WHERE "intnum"=%s' else: - expected = 'SELECT "intnum" "intnum" FROM "intfields" WHERE "intnum"=?' - self.assertEqual(sql, expected) - - def test_values_list(self): - sql = IntFields.filter(intnum=1).values_list("intnum").sql() - if self.dialect == "mysql": - expected = "SELECT `intnum` `0` FROM `intfields` WHERE `intnum`=%s" - elif self.dialect == "postgres": - if self.is_psycopg: - expected = 'SELECT "intnum" "0" FROM "intfields" WHERE "intnum"=%s' - else: - expected = 'SELECT "intnum" "0" FROM "intfields" WHERE "intnum"=$1' + expected = 'SELECT "intnum" "intnum" FROM "intfields" WHERE "intnum"=$1' + else: + expected = 'SELECT "intnum" "intnum" FROM "intfields" WHERE "intnum"=?' + assert sql == expected + + +def test_values_list(sql_context): + db, dialect, is_psycopg = sql_context + sql = IntFields.filter(intnum=1).values_list("intnum").sql() + if dialect == "mysql": + expected = "SELECT `intnum` `0` FROM `intfields` WHERE `intnum`=%s" + elif dialect == "postgres": + if is_psycopg: + expected = 'SELECT "intnum" "0" FROM "intfields" WHERE "intnum"=%s' else: - expected = 'SELECT "intnum" "0" FROM "intfields" WHERE "intnum"=?' 
- self.assertEqual(sql, expected) - - def test_exists(self): - sql = IntFields.filter(intnum=1).exists().sql() - if self.dialect == "mysql": - expected = "SELECT 1 FROM `intfields` WHERE `intnum`=%s LIMIT %s" - elif self.dialect == "postgres": - if self.is_psycopg: - expected = 'SELECT 1 FROM "intfields" WHERE "intnum"=%s LIMIT %s' - else: - expected = 'SELECT 1 FROM "intfields" WHERE "intnum"=$1 LIMIT $2' - elif self.dialect == "mssql": - expected = 'SELECT 1 FROM "intfields" WHERE "intnum"=? ORDER BY (SELECT 0) OFFSET 0 ROWS FETCH NEXT ? ROWS ONLY' + expected = 'SELECT "intnum" "0" FROM "intfields" WHERE "intnum"=$1' + else: + expected = 'SELECT "intnum" "0" FROM "intfields" WHERE "intnum"=?' + assert sql == expected + + +def test_exists(sql_context): + db, dialect, is_psycopg = sql_context + sql = IntFields.filter(intnum=1).exists().sql() + if dialect == "mysql": + expected = "SELECT 1 FROM `intfields` WHERE `intnum`=%s LIMIT %s" + elif dialect == "postgres": + if is_psycopg: + expected = 'SELECT 1 FROM "intfields" WHERE "intnum"=%s LIMIT %s' else: - expected = 'SELECT 1 FROM "intfields" WHERE "intnum"=? LIMIT ?' - self.assertEqual(sql, expected) - - def test_count(self): - sql = IntFields.all().filter(intnum=1).count().sql() - if self.dialect == "mysql": - expected = "SELECT COUNT(*) FROM `intfields` WHERE `intnum`=%s" - elif self.dialect == "postgres": - if self.is_psycopg: - expected = 'SELECT COUNT(*) FROM "intfields" WHERE "intnum"=%s' - else: - expected = 'SELECT COUNT(*) FROM "intfields" WHERE "intnum"=$1' + expected = 'SELECT 1 FROM "intfields" WHERE "intnum"=$1 LIMIT $2' + elif dialect == "mssql": + expected = 'SELECT 1 FROM "intfields" WHERE "intnum"=? ORDER BY (SELECT 0) OFFSET 0 ROWS FETCH NEXT ? ROWS ONLY' + else: + expected = 'SELECT 1 FROM "intfields" WHERE "intnum"=? LIMIT ?' 
+ assert sql == expected + + +def test_count(sql_context): + db, dialect, is_psycopg = sql_context + sql = IntFields.all().filter(intnum=1).count().sql() + if dialect == "mysql": + expected = "SELECT COUNT(*) FROM `intfields` WHERE `intnum`=%s" + elif dialect == "postgres": + if is_psycopg: + expected = 'SELECT COUNT(*) FROM "intfields" WHERE "intnum"=%s' else: - expected = 'SELECT COUNT(*) FROM "intfields" WHERE "intnum"=?' - self.assertEqual(sql, expected) - - def test_update(self): - sql = IntFields.filter(intnum=2).update(intnum=1).sql() - if self.dialect == "mysql": - expected = "UPDATE `intfields` SET `intnum`=%s WHERE `intnum`=%s" - elif self.dialect == "postgres": - if self.is_psycopg: - expected = 'UPDATE "intfields" SET "intnum"=%s WHERE "intnum"=%s' - else: - expected = 'UPDATE "intfields" SET "intnum"=$1 WHERE "intnum"=$2' + expected = 'SELECT COUNT(*) FROM "intfields" WHERE "intnum"=$1' + else: + expected = 'SELECT COUNT(*) FROM "intfields" WHERE "intnum"=?' + assert sql == expected + + +def test_update(sql_context): + db, dialect, is_psycopg = sql_context + sql = IntFields.filter(intnum=2).update(intnum=1).sql() + if dialect == "mysql": + expected = "UPDATE `intfields` SET `intnum`=%s WHERE `intnum`=%s" + elif dialect == "postgres": + if is_psycopg: + expected = 'UPDATE "intfields" SET "intnum"=%s WHERE "intnum"=%s' else: - expected = 'UPDATE "intfields" SET "intnum"=? WHERE "intnum"=?' - self.assertEqual(sql, expected) - - def test_delete(self): - sql = IntFields.filter(intnum=2).delete().sql() - if self.dialect == "mysql": - expected = "DELETE FROM `intfields` WHERE `intnum`=%s" - elif self.dialect == "postgres": - if self.is_psycopg: - expected = 'DELETE FROM "intfields" WHERE "intnum"=%s' - else: - expected = 'DELETE FROM "intfields" WHERE "intnum"=$1' + expected = 'UPDATE "intfields" SET "intnum"=$1 WHERE "intnum"=$2' + else: + expected = 'UPDATE "intfields" SET "intnum"=? WHERE "intnum"=?' 
+ assert sql == expected + + +def test_delete(sql_context): + db, dialect, is_psycopg = sql_context + sql = IntFields.filter(intnum=2).delete().sql() + if dialect == "mysql": + expected = "DELETE FROM `intfields` WHERE `intnum`=%s" + elif dialect == "postgres": + if is_psycopg: + expected = 'DELETE FROM "intfields" WHERE "intnum"=%s' else: - expected = 'DELETE FROM "intfields" WHERE "intnum"=?' - self.assertEqual(sql, expected) - - async def test_bulk_update(self): - obj1 = await IntFields.create(intnum=1) - obj2 = await IntFields.create(intnum=2) - obj1.intnum = obj1.intnum + 1 - obj2.intnum = obj2.intnum + 1 - sql = IntFields.bulk_update([obj1, obj2], fields=["intnum"]).sql() - - if self.dialect == "mysql": - expected = "UPDATE `intfields` SET `intnum`=CASE WHEN `id`=%s THEN %s WHEN `id`=%s THEN %s END WHERE `id` IN (%s,%s)" - elif self.dialect == "postgres": - if self.is_psycopg: - expected = 'UPDATE "intfields" SET "intnum"=CASE WHEN "id"=%s THEN CAST(%s AS INT) WHEN "id"=%s THEN CAST(%s AS INT) END WHERE "id" IN (%s,%s)' - else: - expected = 'UPDATE "intfields" SET "intnum"=CASE WHEN "id"=$1 THEN CAST($2 AS INT) WHEN "id"=$3 THEN CAST($4 AS INT) END WHERE "id" IN ($5,$6)' + expected = 'DELETE FROM "intfields" WHERE "intnum"=$1' + else: + expected = 'DELETE FROM "intfields" WHERE "intnum"=?' 
+ assert sql == expected + + +@pytest.mark.asyncio +async def test_bulk_update(sql_context): + db, dialect, is_psycopg = sql_context + obj1 = await IntFields.create(intnum=1) + obj2 = await IntFields.create(intnum=2) + obj1.intnum = obj1.intnum + 1 + obj2.intnum = obj2.intnum + 1 + sql = IntFields.bulk_update([obj1, obj2], fields=["intnum"]).sql() + + if dialect == "mysql": + expected = "UPDATE `intfields` SET `intnum`=CASE WHEN `id`=%s THEN %s WHEN `id`=%s THEN %s END WHERE `id` IN (%s,%s)" + elif dialect == "postgres": + if is_psycopg: + expected = 'UPDATE "intfields" SET "intnum"=CASE WHEN "id"=%s THEN CAST(%s AS INT) WHEN "id"=%s THEN CAST(%s AS INT) END WHERE "id" IN (%s,%s)' else: - expected = 'UPDATE "intfields" SET "intnum"=CASE WHEN "id"=? THEN ? WHEN "id"=? THEN ? END WHERE "id" IN (?,?)' - self.assertEqual(sql, expected) - - async def test_bulk_create_autogenerated_pk(self): - sql = IntFields.bulk_create( - [IntFields(intnum=1, intnum_null=2), IntFields(intnum=3, intnum_null=4)] - ).sql() - if self.dialect == "mysql": - expected = "INSERT INTO `intfields` (`intnum`,`intnum_null`) VALUES (%s,%s)" - elif self.dialect == "postgres": - if self.is_psycopg: - expected = ( - 'INSERT INTO "intfields" ("intnum","intnum_null") VALUES (%s,%s) RETURNING "id"' - ) - else: - expected = ( - 'INSERT INTO "intfields" ("intnum","intnum_null") VALUES ($1,$2) RETURNING "id"' - ) + expected = 'UPDATE "intfields" SET "intnum"=CASE WHEN "id"=$1 THEN CAST($2 AS INT) WHEN "id"=$3 THEN CAST($4 AS INT) END WHERE "id" IN ($5,$6)' + else: + expected = 'UPDATE "intfields" SET "intnum"=CASE WHEN "id"=? THEN ? WHEN "id"=? THEN ? 
END WHERE "id" IN (?,?)' + assert sql == expected + + +@pytest.mark.asyncio +async def test_bulk_create_autogenerated_pk(sql_context): + db, dialect, is_psycopg = sql_context + sql = IntFields.bulk_create( + [IntFields(intnum=1, intnum_null=2), IntFields(intnum=3, intnum_null=4)] + ).sql() + if dialect == "mysql": + expected = "INSERT INTO `intfields` (`intnum`,`intnum_null`) VALUES (%s,%s)" + elif dialect == "postgres": + if is_psycopg: + expected = ( + 'INSERT INTO "intfields" ("intnum","intnum_null") VALUES (%s,%s) RETURNING "id"' + ) else: - expected = 'INSERT INTO "intfields" ("intnum","intnum_null") VALUES (?,?)' - self.assertEqual(sql, expected) - - async def test_bulk_create_specified_pk(self): - sql = IntFields.bulk_create([IntFields(id=1, intnum=1), IntFields(id=2, intnum=2)]).sql() - if self.dialect == "mysql": - expected = "INSERT INTO `intfields` (`id`,`intnum`,`intnum_null`) VALUES (%s,%s,%s)" - elif self.dialect == "postgres": - if self.is_psycopg: - expected = 'INSERT INTO "intfields" ("id","intnum","intnum_null") VALUES (%s,%s,%s)' - else: - expected = 'INSERT INTO "intfields" ("id","intnum","intnum_null") VALUES ($1,$2,$3)' + expected = ( + 'INSERT INTO "intfields" ("intnum","intnum_null") VALUES ($1,$2) RETURNING "id"' + ) + else: + expected = 'INSERT INTO "intfields" ("intnum","intnum_null") VALUES (?,?)' + assert sql == expected + + +@pytest.mark.asyncio +async def test_bulk_create_specified_pk(sql_context): + db, dialect, is_psycopg = sql_context + sql = IntFields.bulk_create([IntFields(id=1, intnum=1), IntFields(id=2, intnum=2)]).sql() + if dialect == "mysql": + expected = "INSERT INTO `intfields` (`id`,`intnum`,`intnum_null`) VALUES (%s,%s,%s)" + elif dialect == "postgres": + if is_psycopg: + expected = 'INSERT INTO "intfields" ("id","intnum","intnum_null") VALUES (%s,%s,%s)' else: - expected = 'INSERT INTO "intfields" ("id","intnum","intnum_null") VALUES (?,?,?)' - self.assertEqual(sql, expected) + expected = 'INSERT INTO "intfields" 
("id","intnum","intnum_null") VALUES ($1,$2,$3)' + else: + expected = 'INSERT INTO "intfields" ("id","intnum","intnum_null") VALUES (?,?,?)' + assert sql == expected diff --git a/tests/test_table_name.py b/tests/test_table_name.py index 802ccbd75..62ea9e24c 100644 --- a/tests/test_table_name.py +++ b/tests/test_table_name.py @@ -1,5 +1,8 @@ -from tortoise import Tortoise, fields -from tortoise.contrib.test import SimpleTestCase +import pytest +import pytest_asyncio + +from tortoise import fields +from tortoise.context import TortoiseContext from tortoise.models import Model @@ -21,22 +24,25 @@ class Meta: table = "my_custom_table" -class TestTableNameGenerator(SimpleTestCase): - async def asyncSetUp(self): - await super().asyncSetUp() - await Tortoise.init( +@pytest_asyncio.fixture +async def table_name_db(): + """Fixture for table name generator tests with in-memory SQLite.""" + ctx = TortoiseContext() + async with ctx: + await ctx.init( db_url="sqlite://:memory:", modules={"models": [__name__]}, table_name_generator=table_name_generator, ) - await Tortoise.generate_schemas() + await ctx.generate_schemas() + yield ctx + - async def test_glabal_name_generator(self): - self.assertEqual(Tournament._meta.db_table, "test_tournament") +@pytest.mark.asyncio +async def test_glabal_name_generator(table_name_db): + assert Tournament._meta.db_table == "test_tournament" - async def test_custom_table_name_precedence(self): - self.assertEqual(CustomTable._meta.db_table, "my_custom_table") - async def _tearDownDB(self) -> None: - # Explicitly close aiosqlite connection to fix ResourceWarning - await Tortoise.get_connection("default").close() +@pytest.mark.asyncio +async def test_custom_table_name_precedence(table_name_db): + assert CustomTable._meta.db_table == "my_custom_table" diff --git a/tests/test_transactions.py b/tests/test_transactions.py index c44945360..789f14259 100644 --- a/tests/test_transactions.py +++ b/tests/test_transactions.py @@ -1,8 +1,10 @@ from 
unittest.mock import Mock +import pytest + from tests.testmodels import CharPkModel, Event, Team, Tournament from tortoise import connections -from tortoise.contrib import test +from tortoise.contrib.test import requireCapability from tortoise.exceptions import OperationalError, TransactionManagementError from tortoise.transactions import atomic, in_transaction @@ -20,320 +22,406 @@ async def atomic_decorated_func(): return tournament -@test.requireCapability(supports_transactions=True) -class TestTransactions(test.IsolatedTestCase): - """This test case uses IsolatedTestCase to ensure that - - there is no open transaction before the test starts - - commits in these tests do not impact other tests - """ +@requireCapability(supports_transactions=True) +@pytest.mark.asyncio +async def test_transactions(db_isolated): + """Test basic transaction rollback on exception.""" + with pytest.raises(SomeException): + async with in_transaction(): + tournament = Tournament(name="Test") + await tournament.save() + await Tournament.filter(id=tournament.id).update(name="Updated name") + saved_event = await Tournament.filter(name="Updated name").first() + assert saved_event.id == tournament.id + raise SomeException("Some error") - async def test_transactions(self): - with self.assertRaises(SomeException): - async with in_transaction(): - tournament = Tournament(name="Test") - await tournament.save() - await Tournament.filter(id=tournament.id).update(name="Updated name") - saved_event = await Tournament.filter(name="Updated name").first() - self.assertEqual(saved_event.id, tournament.id) - raise SomeException("Some error") + saved_event = await Tournament.filter(name="Updated name").first() + assert saved_event is None - saved_event = await Tournament.filter(name="Updated name").first() - self.assertIsNone(saved_event) - async def test_get_or_create_transaction_using_db(self): - async with in_transaction() as connection: - obj = await CharPkModel.get_or_create(id="FooMip", 
using_db=connection) - self.assertIsNotNone(obj) - await connection.rollback() +@requireCapability(supports_transactions=True) +@pytest.mark.asyncio +async def test_get_or_create_transaction_using_db(db_isolated): + """Test get_or_create with explicit connection rollback.""" + async with in_transaction() as connection: + obj = await CharPkModel.get_or_create(id="FooMip", using_db=connection) + assert obj is not None + await connection.rollback() - obj2 = await CharPkModel.filter(id="FooMip").first() - self.assertIsNone(obj2) + obj2 = await CharPkModel.filter(id="FooMip").first() + assert obj2 is None - async def test_consequent_nested_transactions(self): + +@requireCapability(supports_transactions=True) +@pytest.mark.asyncio +async def test_consequent_nested_transactions(db_isolated): + """Test consequent nested transactions.""" + async with in_transaction(): + await Tournament.create(name="Test") async with in_transaction(): - await Tournament.create(name="Test") - async with in_transaction(): - await Tournament.create(name="Nested 1") - await Tournament.create(name="Test 2") + await Tournament.create(name="Nested 1") + await Tournament.create(name="Test 2") + async with in_transaction(): + await Tournament.create(name="Nested 2") + + assert set(await Tournament.all().values_list("name", flat=True)) == { + "Test", + "Nested 1", + "Test 2", + "Nested 2", + } + + +@requireCapability(supports_transactions=True) +@pytest.mark.asyncio +async def test_caught_exception_in_nested_transaction(db_isolated): + """Test that caught exception in nested transaction only rolls back inner.""" + async with in_transaction(): + tournament = await Tournament.create(name="Test") + await Tournament.filter(id=tournament.id).update(name="Updated name") + saved_event = await Tournament.filter(name="Updated name").first() + assert saved_event.id == tournament.id + with pytest.raises(SomeException): async with in_transaction(): - await Tournament.create(name="Nested 2") + tournament = await 
Tournament.create(name="Nested") + saved_tournament = await Tournament.filter(name="Nested").first() + assert tournament.id == saved_tournament.id + raise SomeException("Some error") - self.assertEqual( - set(await Tournament.all().values_list("name", flat=True)), - set(["Test", "Nested 1", "Test 2", "Nested 2"]), - ) + saved_event = await Tournament.filter(name="Updated name").first() + assert saved_event is not None + not_saved_event = await Tournament.filter(name="Nested").first() + assert not_saved_event is None - async def test_caught_exception_in_nested_transaction(self): + +@requireCapability(supports_transactions=True) +@pytest.mark.asyncio +async def test_nested_tx_do_not_commit(db_isolated): + """Test that nested transactions don't commit if outer fails.""" + with pytest.raises(SomeException): async with in_transaction(): tournament = await Tournament.create(name="Test") - await Tournament.filter(id=tournament.id).update(name="Updated name") - saved_event = await Tournament.filter(name="Updated name").first() - self.assertEqual(saved_event.id, tournament.id) - with self.assertRaises(SomeException): - async with in_transaction(): - tournament = await Tournament.create(name="Nested") - saved_tournament = await Tournament.filter(name="Nested").first() - self.assertEqual(tournament.id, saved_tournament.id) - raise SomeException("Some error") + async with in_transaction(): + tournament.name = "Nested" + await tournament.save() - saved_event = await Tournament.filter(name="Updated name").first() - self.assertIsNotNone(saved_event) - not_saved_event = await Tournament.filter(name="Nested").first() - self.assertIsNone(not_saved_event) + raise SomeException("Some error") - async def test_nested_tx_do_not_commit(self): - with self.assertRaises(SomeException): - async with in_transaction(): - tournament = await Tournament.create(name="Test") + assert await Tournament.filter(id=tournament.id).count() == 0 + + +@requireCapability(supports_transactions=True) 
+@pytest.mark.asyncio +async def test_nested_rollback_does_not_enable_autocommit(db_isolated): + """Test that nested rollback doesn't enable autocommit.""" + with pytest.raises(SomeException, match="Error 2"): + async with in_transaction(): + await Tournament.create(name="Test1") + with pytest.raises(SomeException, match="Error 1"): async with in_transaction(): - tournament.name = "Nested" - await tournament.save() + await Tournament.create(name="Test2") + raise SomeException("Error 1") - raise SomeException("Some error") + await Tournament.create(name="Test3") + raise SomeException("Error 2") - self.assertEqual(await Tournament.filter(id=tournament.id).count(), 0) + assert await Tournament.all().count() == 0 - async def test_nested_rollback_does_not_enable_autocommit(self): - with self.assertRaisesRegex(SomeException, "Error 2"): + +@requireCapability(supports_transactions=True) +@pytest.mark.asyncio +async def test_nested_savepoint_rollbacks(db_isolated): + """Test nested savepoint rollbacks.""" + async with in_transaction(): + await Tournament.create(name="Outer Transaction 1") + + with pytest.raises(SomeException, match="Inner 1"): async with in_transaction(): - await Tournament.create(name="Test1") - with self.assertRaisesRegex(SomeException, "Error 1"): - async with in_transaction(): - await Tournament.create(name="Test2") - raise SomeException("Error 1") + await Tournament.create(name="Inner 1") + raise SomeException("Inner 1") - await Tournament.create(name="Test3") - raise SomeException("Error 2") + await Tournament.create(name="Outer Transaction 2") - self.assertEqual(await Tournament.all().count(), 0) + with pytest.raises(SomeException, match="Inner 2"): + async with in_transaction(): + await Tournament.create(name="Inner 2") + raise SomeException("Inner 2") - async def test_nested_savepoint_rollbacks(self): - async with in_transaction(): - await Tournament.create(name="Outer Transaction 1") + await Tournament.create(name="Outer Transaction 3") - with 
self.assertRaisesRegex(SomeException, "Inner 1"): - async with in_transaction(): - await Tournament.create(name="Inner 1") - raise SomeException("Inner 1") + assert await Tournament.all().values_list("name", flat=True) == [ + "Outer Transaction 1", + "Outer Transaction 2", + "Outer Transaction 3", + ] - await Tournament.create(name="Outer Transaction 2") - with self.assertRaisesRegex(SomeException, "Inner 2"): - async with in_transaction(): - await Tournament.create(name="Inner 2") - raise SomeException("Inner 2") +@requireCapability(supports_transactions=True) +@pytest.mark.asyncio +async def test_nested_savepoint_rollback_but_other_succeed(db_isolated): + """Test nested savepoint rollback while other nested transactions succeed.""" + async with in_transaction(): + await Tournament.create(name="Outer Transaction 1") - await Tournament.create(name="Outer Transaction 3") + with pytest.raises(SomeException, match="Inner 1"): + async with in_transaction(): + await Tournament.create(name="Inner 1") + raise SomeException("Inner 1") - self.assertEqual( - await Tournament.all().values_list("name", flat=True), - ["Outer Transaction 1", "Outer Transaction 2", "Outer Transaction 3"], - ) + await Tournament.create(name="Outer Transaction 2") - async def test_nested_savepoint_rollback_but_other_succeed(self): async with in_transaction(): - await Tournament.create(name="Outer Transaction 1") + await Tournament.create(name="Inner 2") - with self.assertRaisesRegex(SomeException, "Inner 1"): - async with in_transaction(): - await Tournament.create(name="Inner 1") - raise SomeException("Inner 1") + await Tournament.create(name="Outer Transaction 3") - await Tournament.create(name="Outer Transaction 2") + assert await Tournament.all().values_list("name", flat=True) == [ + "Outer Transaction 1", + "Outer Transaction 2", + "Inner 2", + "Outer Transaction 3", + ] + +@requireCapability(supports_transactions=True) +@pytest.mark.asyncio +async def 
test_three_nested_transactions(db_isolated): + """Test three levels of nested transactions.""" + async with in_transaction(): + tournament1 = await Tournament.create(name="Test") + async with in_transaction(): + tournament2 = await Tournament.create(name="Nested") async with in_transaction(): - await Tournament.create(name="Inner 2") + tournament3 = await Tournament.create(name="Nested2") - await Tournament.create(name="Outer Transaction 3") + assert ( + await Tournament.filter(id__in=[tournament1.id, tournament2.id, tournament3.id]).count() + == 3 + ) - self.assertEqual( - await Tournament.all().values_list("name", flat=True), - ["Outer Transaction 1", "Outer Transaction 2", "Inner 2", "Outer Transaction 3"], - ) - async def test_three_nested_transactions(self): - async with in_transaction(): - tournament1 = await Tournament.create(name="Test") - async with in_transaction(): - tournament2 = await Tournament.create(name="Nested") - async with in_transaction(): - tournament3 = await Tournament.create(name="Nested2") - - self.assertEqual( - await Tournament.filter( - id__in=[tournament1.id, tournament2.id, tournament3.id] - ).count(), - 3, - ) - - async def test_transaction_decorator(self): - @atomic() - async def bound_to_succeed(): - tournament = Tournament(name="Test") - await tournament.save() - await Tournament.filter(id=tournament.id).update(name="Updated name") - saved_event = await Tournament.filter(name="Updated name").first() - self.assertEqual(saved_event.id, tournament.id) - return tournament +@requireCapability(supports_transactions=True) +@pytest.mark.asyncio +async def test_transaction_decorator(db_isolated): + """Test @atomic decorator with successful transaction.""" - tournament = await bound_to_succeed() + @atomic() + async def bound_to_succeed(): + tournament = Tournament(name="Test") + await tournament.save() + await Tournament.filter(id=tournament.id).update(name="Updated name") saved_event = await Tournament.filter(name="Updated name").first() - 
self.assertEqual(saved_event.id, tournament.id) + assert saved_event.id == tournament.id + return tournament - async def test_transaction_decorator_defined_before_init(self): - tournament = await atomic_decorated_func() - saved_event = await Tournament.filter(name="Test").first() - self.assertEqual(saved_event.id, tournament.id) + tournament = await bound_to_succeed() + saved_event = await Tournament.filter(name="Updated name").first() + assert saved_event.id == tournament.id - async def test_transaction_decorator_fail(self): - tournament = await Tournament.create(name="Test") - @atomic() - async def bound_to_fall(): - saved_event = await Tournament.filter(name="Test").first() - self.assertEqual(saved_event.id, tournament.id) - await Tournament.filter(id=tournament.id).update(name="Updated name") - saved_event = await Tournament.filter(name="Updated name").first() - self.assertEqual(saved_event.id, tournament.id) - raise OperationalError() +@requireCapability(supports_transactions=True) +@pytest.mark.asyncio +async def test_transaction_decorator_defined_before_init(db_isolated): + """Test @atomic decorator defined before Tortoise init.""" + tournament = await atomic_decorated_func() + saved_event = await Tournament.filter(name="Test").first() + assert saved_event.id == tournament.id + - with self.assertRaises(OperationalError): - await bound_to_fall() +@requireCapability(supports_transactions=True) +@pytest.mark.asyncio +async def test_transaction_decorator_fail(db_isolated): + """Test @atomic decorator with failing transaction.""" + tournament = await Tournament.create(name="Test") + + @atomic() + async def bound_to_fall(): saved_event = await Tournament.filter(name="Test").first() - self.assertEqual(saved_event.id, tournament.id) + assert saved_event.id == tournament.id + await Tournament.filter(id=tournament.id).update(name="Updated name") saved_event = await Tournament.filter(name="Updated name").first() - self.assertIsNone(saved_event) + assert saved_event.id 
== tournament.id + raise OperationalError() + + with pytest.raises(OperationalError): + await bound_to_fall() + saved_event = await Tournament.filter(name="Test").first() + assert saved_event.id == tournament.id + saved_event = await Tournament.filter(name="Updated name").first() + assert saved_event is None + + +@requireCapability(supports_transactions=True) +@pytest.mark.asyncio +async def test_transaction_with_m2m_relations(db_isolated): + """Test transaction with M2M relations.""" + async with in_transaction(): + tournament = await Tournament.create(name="Test") + event = await Event.create(name="Test event", tournament=tournament) + team = await Team.create(name="Test team") + await event.participants.add(team) - async def test_transaction_with_m2m_relations(self): - async with in_transaction(): - tournament = await Tournament.create(name="Test") - event = await Event.create(name="Test event", tournament=tournament) - team = await Team.create(name="Test team") - await event.participants.add(team) - - async def test_transaction_exception_1(self): - with self.assertRaises(TransactionManagementError): - async with in_transaction() as connection: - await connection.rollback() - await connection.rollback() - - async def test_transaction_exception_2(self): - with self.assertRaises(TransactionManagementError): - async with in_transaction() as connection: - await connection.commit() - await connection.commit() - - async def test_insert_await_across_transaction_fail(self): - tournament = Tournament(name="Test") - query = tournament.save() # pylint: disable=E1111 - try: - async with in_transaction(): - await query - raise KeyError("moo") - except KeyError: - pass +@requireCapability(supports_transactions=True) +@pytest.mark.asyncio +async def test_transaction_exception_1(db_isolated): + """Test double rollback raises TransactionManagementError.""" + with pytest.raises(TransactionManagementError): + async with in_transaction() as connection: + await connection.rollback() 
+ await connection.rollback() - self.assertEqual(await Tournament.all(), []) - async def test_insert_await_across_transaction_success(self): - tournament = Tournament(name="Test") - query = tournament.save() # pylint: disable=E1111 +@requireCapability(supports_transactions=True) +@pytest.mark.asyncio +async def test_transaction_exception_2(db_isolated): + """Test double commit raises TransactionManagementError.""" + with pytest.raises(TransactionManagementError): + async with in_transaction() as connection: + await connection.commit() + await connection.commit() + +@requireCapability(supports_transactions=True) +@pytest.mark.asyncio +async def test_insert_await_across_transaction_fail(db_isolated): + """Test insert await across transaction that fails.""" + tournament = Tournament(name="Test") + query = tournament.save() # pylint: disable=E1111 + + try: async with in_transaction(): await query + raise KeyError("moo") + except KeyError: + pass - self.assertEqual(await Tournament.all(), [tournament]) + assert await Tournament.all() == [] - async def test_update_await_across_transaction_fail(self): - obj = await Tournament.create(name="Test1") - query = Tournament.filter(id=obj.id).update(name="Test2") - try: - async with in_transaction(): - await query - raise KeyError("moo") - except KeyError: - pass +@requireCapability(supports_transactions=True) +@pytest.mark.asyncio +async def test_insert_await_across_transaction_success(db_isolated): + """Test insert await across transaction that succeeds.""" + tournament = Tournament(name="Test") + query = tournament.save() # pylint: disable=E1111 - self.assertEqual( - await Tournament.all().values("id", "name"), [{"id": obj.id, "name": "Test1"}] - ) + async with in_transaction(): + await query - async def test_update_await_across_transaction_success(self): - obj = await Tournament.create(name="Test1") + assert await Tournament.all() == [tournament] - query = Tournament.filter(id=obj.id).update(name="Test2") + 
+@requireCapability(supports_transactions=True) +@pytest.mark.asyncio +async def test_update_await_across_transaction_fail(db_isolated): + """Test update await across transaction that fails.""" + obj = await Tournament.create(name="Test1") + + query = Tournament.filter(id=obj.id).update(name="Test2") + try: async with in_transaction(): await query + raise KeyError("moo") + except KeyError: + pass - self.assertEqual( - await Tournament.all().values("id", "name"), [{"id": obj.id, "name": "Test2"}] - ) + assert await Tournament.all().values("id", "name") == [{"id": obj.id, "name": "Test1"}] - async def test_delete_await_across_transaction_fail(self): - obj = await Tournament.create(name="Test1") - query = Tournament.filter(id=obj.id).delete() - try: - async with in_transaction(): - await query - raise KeyError("moo") - except KeyError: - pass +@requireCapability(supports_transactions=True) +@pytest.mark.asyncio +async def test_update_await_across_transaction_success(db_isolated): + """Test update await across transaction that succeeds.""" + obj = await Tournament.create(name="Test1") - self.assertEqual( - await Tournament.all().values("id", "name"), [{"id": obj.id, "name": "Test1"}] - ) + query = Tournament.filter(id=obj.id).update(name="Test2") + async with in_transaction(): + await query + + assert await Tournament.all().values("id", "name") == [{"id": obj.id, "name": "Test2"}] - async def test_delete_await_across_transaction_success(self): - obj = await Tournament.create(name="Test1") - query = Tournament.filter(id=obj.id).delete() +@requireCapability(supports_transactions=True) +@pytest.mark.asyncio +async def test_delete_await_across_transaction_fail(db_isolated): + """Test delete await across transaction that fails.""" + obj = await Tournament.create(name="Test1") + + query = Tournament.filter(id=obj.id).delete() + try: async with in_transaction(): await query + raise KeyError("moo") + except KeyError: + pass - self.assertEqual(await Tournament.all(), []) + 
assert await Tournament.all().values("id", "name") == [{"id": obj.id, "name": "Test1"}] - async def test_select_await_across_transaction_fail(self): - try: - async with in_transaction(): - query = Tournament.all().values("name") - await Tournament.create(name="Test1") - result = await query - raise KeyError("moo") - except KeyError: - pass - self.assertEqual(result, [{"name": "Test1"}]) - self.assertEqual(await Tournament.all(), []) +@requireCapability(supports_transactions=True) +@pytest.mark.asyncio +async def test_delete_await_across_transaction_success(db_isolated): + """Test delete await across transaction that succeeds.""" + obj = await Tournament.create(name="Test1") + + query = Tournament.filter(id=obj.id).delete() + async with in_transaction(): + await query - async def test_select_await_across_transaction_success(self): + assert await Tournament.all() == [] + + +@requireCapability(supports_transactions=True) +@pytest.mark.asyncio +async def test_select_await_across_transaction_fail(db_isolated): + """Test select await across transaction that fails.""" + try: async with in_transaction(): - query = Tournament.all().values("id", "name") - obj = await Tournament.create(name="Test1") + query = Tournament.all().values("name") + await Tournament.create(name="Test1") result = await query + raise KeyError("moo") + except KeyError: + pass + + assert result == [{"name": "Test1"}] + assert await Tournament.all() == [] + + +@requireCapability(supports_transactions=True) +@pytest.mark.asyncio +async def test_select_await_across_transaction_success(db_isolated): + """Test select await across transaction that succeeds.""" + async with in_transaction(): + query = Tournament.all().values("id", "name") + obj = await Tournament.create(name="Test1") + result = await query + + assert result == [{"id": obj.id, "name": "Test1"}] + assert await Tournament.all().values("id", "name") == [{"id": obj.id, "name": "Test1"}] + + +@requireCapability(supports_transactions=True) 
+@pytest.mark.asyncio +async def test_rollback_raising_exception(db_isolated): + """Tests that if a rollback raises an exception, the connection context is restored.""" + conn = connections.get("models") + with pytest.raises(ValueError, match="rollback"): + async with conn._in_transaction() as tx_conn: + tx_conn.rollback = Mock(side_effect=ValueError("rollback")) + raise ValueError("initial exception") + + assert connections.get("models") == conn + + +@requireCapability(supports_transactions=True) +@pytest.mark.asyncio +async def test_commit_raising_exception(db_isolated): + """Tests that if a commit raises an exception, the connection context is restored.""" + conn = connections.get("models") + with pytest.raises(ValueError, match="commit"): + async with conn._in_transaction() as tx_conn: + tx_conn.commit = Mock(side_effect=ValueError("commit")) - self.assertEqual(result, [{"id": obj.id, "name": "Test1"}]) - self.assertEqual( - await Tournament.all().values("id", "name"), [{"id": obj.id, "name": "Test1"}] - ) - - async def test_rollback_raising_exception(self): - """Tests that if a rollback raises an exception, the connection context is restored.""" - conn = connections.get("models") - with self.assertRaisesRegex(ValueError, "rollback"): - async with conn._in_transaction() as tx_conn: - tx_conn.rollback = Mock(side_effect=ValueError("rollback")) - raise ValueError("initial exception") - - self.assertEqual(connections.get("models"), conn) - - async def test_commit_raising_exception(self): - """Tests that if a commit raises an exception, the connection context is restored.""" - conn = connections.get("models") - with self.assertRaisesRegex(ValueError, "commit"): - async with conn._in_transaction() as tx_conn: - tx_conn.commit = Mock(side_effect=ValueError("commit")) - - self.assertEqual(connections.get("models"), conn) + assert connections.get("models") == conn diff --git a/tests/test_two_databases.py b/tests/test_two_databases.py index 0aec81be9..23a2b6a06 100644 
--- a/tests/test_two_databases.py +++ b/tests/test_two_databases.py @@ -1,93 +1,130 @@ +import os + +import pytest +import pytest_asyncio + from tests.testmodels import Event, EventTwo, TeamTwo, Tournament -from tortoise import Tortoise, connections -from tortoise.backends.oracle import OracleClient -from tortoise.contrib import test +from tortoise.context import TortoiseContext from tortoise.exceptions import OperationalError, ParamsError from tortoise.transactions import in_transaction +# Optional import for Oracle client that requires system dependencies +try: + from tortoise.backends.oracle import OracleClient +except ImportError: + OracleClient = None # type: ignore[misc,assignment] -class TestTwoDatabases(test.SimpleTestCase): - async def asyncSetUp(self): - await super().asyncSetUp() - if Tortoise._inited: - await self._tearDownDB() - first_db_config = test.getDBConfig(app_label="models", modules=["tests.testmodels"]) - second_db_config = test.getDBConfig(app_label="events", modules=["tests.testmodels"]) - merged_config = { - "connections": {**first_db_config["connections"], **second_db_config["connections"]}, - "apps": {**first_db_config["apps"], **second_db_config["apps"]}, - } - await Tortoise.init(merged_config, _create_db=True) - await Tortoise.generate_schemas() - self.db = connections.get("models") - self.second_db = connections.get("events") - - async def asyncTearDown(self) -> None: - await Tortoise._drop_databases() - await super().asyncTearDown() - - def build_select_sql(self) -> str: - if isinstance(self.db, OracleClient): - return 'SELECT * FROM "eventtwo"' - return "SELECT * FROM eventtwo" - - async def test_two_databases(self): - tournament = await Tournament.create(name="Tournament") - await EventTwo.create(name="Event", tournament_id=tournament.id) - select_sql = self.build_select_sql() - with self.assertRaises(OperationalError): - await self.db.execute_query(select_sql) - _, results = await self.second_db.execute_query(select_sql) - 
self.assertEqual(dict(results[0]), {"id": 1, "name": "Event", "tournament_id": 1}) +@pytest_asyncio.fixture(scope="function") +async def two_databases(): + """Fixture that sets up two separate databases for testing.""" + db_url = os.getenv("TORTOISE_TEST_DB", "sqlite://:memory:") - async def test_two_databases_relation(self): - tournament = await Tournament.create(name="Tournament") - event = await EventTwo.create(name="Event", tournament_id=tournament.id) + from tortoise.backends.base.config_generator import expand_db_url + + # Expand the URL with uniqueness if it's a template + # This ensures "models" and "events" get different DB names if TORTOISE_TEST_DB has {} + db1_config = expand_db_url(db_url, testing=True) + db2_config = expand_db_url(db_url, testing=True) + + ctx = TortoiseContext() + async with ctx: + await ctx.init( + config={ + "connections": { + "models": db1_config, + "events": db2_config, + }, + "apps": { + "models": {"models": ["tests.testmodels"], "default_connection": "models"}, + "events": {"models": ["tests.testmodels"], "default_connection": "events"}, + }, + }, + _create_db=True, + ) + await ctx.generate_schemas() + + db = ctx.connections.get("models") + second_db = ctx.connections.get("events") + + yield db, second_db + + +def build_select_sql(db) -> str: + """Helper function to build SELECT SQL based on database type.""" + if OracleClient is not None and isinstance(db, OracleClient): + return 'SELECT * FROM "eventtwo"' + return "SELECT * FROM eventtwo" + + +@pytest.mark.asyncio +async def test_two_databases(two_databases): + db, second_db = two_databases - select_sql = self.build_select_sql() - with self.assertRaises(OperationalError): - await self.db.execute_query(select_sql) + tournament = await Tournament.create(name="Tournament") + await EventTwo.create(name="Event", tournament_id=tournament.id) - _, results = await self.second_db.execute_query(select_sql) - self.assertEqual(dict(results[0]), {"id": 1, "name": "Event", "tournament_id": 
1}) + select_sql = build_select_sql(db) + with pytest.raises(OperationalError): + await db.execute_query(select_sql) + _, results = await second_db.execute_query(select_sql) + assert dict(results[0]) == {"id": 1, "name": "Event", "tournament_id": 1} - teams = [] - for i in range(2): - team = await TeamTwo.create(name=f"Team {(i + 1)}") - teams.append(team) + +@pytest.mark.asyncio +async def test_two_databases_relation(two_databases): + db, second_db = two_databases + + tournament = await Tournament.create(name="Tournament") + event = await EventTwo.create(name="Event", tournament_id=tournament.id) + + select_sql = build_select_sql(db) + with pytest.raises(OperationalError): + await db.execute_query(select_sql) + + _, results = await second_db.execute_query(select_sql) + assert dict(results[0]) == {"id": 1, "name": "Event", "tournament_id": 1} + + teams = [] + for i in range(2): + team = await TeamTwo.create(name=f"Team {(i + 1)}") + teams.append(team) + await event.participants.add(team) + + assert await TeamTwo.all().order_by("name") == teams + assert await event.participants.all().order_by("name") == teams + + assert await TeamTwo.all().order_by("name").values("id", "name") == [ + {"id": 1, "name": "Team 1"}, + {"id": 2, "name": "Team 2"}, + ] + assert await event.participants.all().order_by("name").values("id", "name") == [ + {"id": 1, "name": "Team 1"}, + {"id": 2, "name": "Team 2"}, + ] + + +@pytest.mark.asyncio +async def test_two_databases_transactions_switch_db(two_databases): + async with in_transaction("models"): + tournament = await Tournament.create(name="Tournament") + await Event.create(name="Event1", tournament=tournament) + async with in_transaction("events"): + event = await EventTwo.create(name="Event2", tournament_id=tournament.id) + team = await TeamTwo.create(name="Team 1") await event.participants.add(team) - self.assertEqual(await TeamTwo.all().order_by("name"), teams) - self.assertEqual(await event.participants.all().order_by("name"), teams) 
+ saved_tournament = await Tournament.filter(name="Tournament").first() + assert tournament.id == saved_tournament.id + saved_event = await EventTwo.filter(tournament_id=tournament.id).first() + assert event.id == saved_event.id - self.assertEqual( - await TeamTwo.all().order_by("name").values("id", "name"), - [{"id": 1, "name": "Team 1"}, {"id": 2, "name": "Team 2"}], - ) - self.assertEqual( - await event.participants.all().order_by("name").values("id", "name"), - [{"id": 1, "name": "Team 1"}, {"id": 2, "name": "Team 2"}], - ) - async def test_two_databases_transactions_switch_db(self): - async with in_transaction("models"): - tournament = await Tournament.create(name="Tournament") - await Event.create(name="Event1", tournament=tournament) - async with in_transaction("events"): - event = await EventTwo.create(name="Event2", tournament_id=tournament.id) - team = await TeamTwo.create(name="Team 1") - await event.participants.add(team) - - saved_tournament = await Tournament.filter(name="Tournament").first() - self.assertEqual(tournament.id, saved_tournament.id) - saved_event = await EventTwo.filter(tournament_id=tournament.id).first() - self.assertEqual(event.id, saved_event.id) - - async def test_two_databases_transaction_paramerror(self): - with self.assertRaisesRegex( - ParamsError, - "You are running with multiple databases, so you should specify connection_name", - ): - async with in_transaction(): - pass +@pytest.mark.asyncio +async def test_two_databases_transaction_paramerror(two_databases): + with pytest.raises( + ParamsError, + match="You are running with multiple databases, so you should specify connection_name", + ): + async with in_transaction(): + pass diff --git a/tests/test_unique_together.py b/tests/test_unique_together.py index 1b3c84955..3f80a6b8a 100644 --- a/tests/test_unique_together.py +++ b/tests/test_unique_together.py @@ -1,29 +1,32 @@ +import pytest + from tests.testmodels import ( Tournament, UniqueTogetherFields, 
UniqueTogetherFieldsWithFK, ) -from tortoise.contrib import test from tortoise.exceptions import IntegrityError -class TestUniqueTogether(test.TestCase): - async def test_unique_together(self): - first_name = "first_name" - last_name = "last_name" +@pytest.mark.asyncio +async def test_unique_together(db): + first_name = "first_name" + last_name = "last_name" + + await UniqueTogetherFields.create(first_name=first_name, last_name=last_name) + with pytest.raises(IntegrityError): await UniqueTogetherFields.create(first_name=first_name, last_name=last_name) - with self.assertRaises(IntegrityError): - await UniqueTogetherFields.create(first_name=first_name, last_name=last_name) - async def test_unique_together_with_foreign_keys(self): - tournament_name = "tournament_name" - text = "text" +@pytest.mark.asyncio +async def test_unique_together_with_foreign_keys(db): + tournament_name = "tournament_name" + text = "text" - tournament = await Tournament.create(name=tournament_name) + tournament = await Tournament.create(name=tournament_name) - await UniqueTogetherFieldsWithFK.create(text=text, tournament=tournament) + await UniqueTogetherFieldsWithFK.create(text=text, tournament=tournament) - with self.assertRaises(IntegrityError): - await UniqueTogetherFieldsWithFK.create(text=text, tournament=tournament) + with pytest.raises(IntegrityError): + await UniqueTogetherFieldsWithFK.create(text=text, tournament=tournament) diff --git a/tests/test_update.py b/tests/test_update.py index 2edcebf9a..634e37821 100644 --- a/tests/test_update.py +++ b/tests/test_update.py @@ -4,6 +4,7 @@ from datetime import datetime, timedelta from typing import Any +import pytest import pytz from pypika_tortoise.terms import Function as PupikaFunction @@ -23,244 +24,272 @@ UUIDFields, ) from tortoise import timezone -from tortoise.contrib import test +from tortoise.contrib.test import requireCapability from tortoise.contrib.test.condition import In, NotEQ from tortoise.expressions import Case, F, Q, 
Subquery, When from tortoise.functions import Function, Upper -class TestUpdate(test.TestCase): - async def test_update(self): - await Tournament.create(name="1") - await Tournament.create(name="3") - rows_affected = await Tournament.all().update(name="2") - self.assertEqual(rows_affected, 2) - - tournament = await Tournament.first() - self.assertEqual(tournament.name, "2") - - async def test_bulk_update(self): - objs = [await Tournament.create(name="1"), await Tournament.create(name="2")] - objs[0].name = "0" - objs[1].name = "1" - rows_affected = await Tournament.bulk_update(objs, fields=["name"], batch_size=100) - self.assertEqual(rows_affected, 2) - self.assertEqual((await Tournament.get(pk=objs[0].pk)).name, "0") - self.assertEqual((await Tournament.get(pk=objs[1].pk)).name, "1") - - async def test_bulk_update_datetime(self): - objs = [ - await DatetimeFields.create(datetime=datetime(2021, 1, 1, tzinfo=pytz.utc)), - await DatetimeFields.create(datetime=datetime(2021, 1, 1, tzinfo=pytz.utc)), - ] - t0 = datetime(2021, 1, 2, tzinfo=pytz.utc) - t1 = datetime(2021, 1, 3, tzinfo=pytz.utc) - objs[0].datetime = t0 - objs[1].datetime = t1 - rows_affected = await DatetimeFields.bulk_update(objs, fields=["datetime"]) - self.assertEqual(rows_affected, 2) - self.assertEqual((await DatetimeFields.get(pk=objs[0].pk)).datetime, t0) - self.assertEqual((await DatetimeFields.get(pk=objs[1].pk)).datetime, t1) - - async def test_bulk_update_pk_non_id(self): - tournament = await Tournament.create(name="") - events = [ - await Event.create(name="1", tournament=tournament), - await Event.create(name="2", tournament=tournament), - ] - events[0].name = "3" - events[1].name = "4" - rows_affected = await Event.bulk_update(events, fields=["name"]) - self.assertEqual(rows_affected, 2) - self.assertEqual((await Event.get(pk=events[0].pk)).name, events[0].name) - self.assertEqual((await Event.get(pk=events[1].pk)).name, events[1].name) - - async def test_bulk_update_pk_uuid(self): - objs = 
[ - await UUIDFields.create(data=uuid.uuid4()), - await UUIDFields.create(data=uuid.uuid4()), - ] - objs[0].data = uuid.uuid4() - objs[1].data = uuid.uuid4() - rows_affected = await UUIDFields.bulk_update(objs, fields=["data"]) - self.assertEqual(rows_affected, 2) - self.assertEqual((await UUIDFields.get(pk=objs[0].pk)).data, objs[0].data) - self.assertEqual((await UUIDFields.get(pk=objs[1].pk)).data, objs[1].data) - - async def test_bulk_renamed_pk_source_field(self): - objs = [ - await SourceFieldPk.create(name="Model 1"), - await SourceFieldPk.create(name="Model 2"), - ] - objs[0].name = "Model 3" - objs[1].name = "Model 4" - rows_affected = await SourceFieldPk.bulk_update(objs, fields=["name"]) - self.assertEqual(rows_affected, 2) - self.assertEqual((await SourceFieldPk.get(pk=objs[0].pk)).name, objs[0].name) - self.assertEqual((await SourceFieldPk.get(pk=objs[1].pk)).name, objs[1].name) - - async def test_bulk_update_json_value(self): - objs = [ - await JSONFields.create(data={}), - await JSONFields.create(data={}), - ] - objs[0].data = [0] - objs[1].data = {"a": 1} - rows_affected = await JSONFields.bulk_update(objs, fields=["data"]) - self.assertEqual(rows_affected, 2) - self.assertEqual((await JSONFields.get(pk=objs[0].pk)).data, objs[0].data) - self.assertEqual((await JSONFields.get(pk=objs[1].pk)).data, objs[1].data) - - @test.requireCapability(dialect=NotEQ("mssql")) - async def test_bulk_update_smallint_none(self): - objs = [ - await SmallIntFields.create(smallintnum=1, smallintnum_null=1), - await SmallIntFields.create(smallintnum=2, smallintnum_null=2), - ] - objs[0].smallintnum_null = None - objs[1].smallintnum_null = None - rows_affected = await SmallIntFields.bulk_update(objs, fields=["smallintnum_null"]) - self.assertEqual(rows_affected, 2) - self.assertEqual((await SmallIntFields.get(pk=objs[0].pk)).smallintnum_null, None) - self.assertEqual((await SmallIntFields.get(pk=objs[1].pk)).smallintnum_null, None) - - async def 
test_bulk_update_custom_field(self): - objs = [ - await EnumFields.create(service=Service.python_programming, currency=Currency.EUR), - await EnumFields.create(service=Service.database_design, currency=Currency.USD), - ] - objs[0].currency = Currency.USD - objs[1].service = Service.system_administration - rows_affected = await EnumFields.bulk_update(objs, fields=["service", "currency"]) - self.assertEqual(rows_affected, 2) - self.assertEqual((await EnumFields.get(pk=objs[0].pk)).currency, Currency.USD) - self.assertEqual( - (await EnumFields.get(pk=objs[1].pk)).service, Service.system_administration - ) - - async def test_update_auto_now(self): - obj = await DefaultUpdate.create() - - updated_at = timezone.now() - timedelta(days=1) - await DefaultUpdate.filter(pk=obj.pk).update(updated_at=updated_at) - - obj1 = await DefaultUpdate.get(pk=obj.pk) - self.assertEqual(obj1.updated_at.date(), updated_at.date()) - - async def test_update_relation(self): - tournament_first = await Tournament.create(name="1") - tournament_second = await Tournament.create(name="2") - - await Event.create(name="1", tournament=tournament_first) - await Event.all().update(tournament=tournament_second) - event = await Event.first() - self.assertEqual(event.tournament_id, tournament_second.id) - - @test.requireCapability(dialect=In("mysql", "sqlite")) - async def test_update_with_custom_function(self): - class JsonSet(Function): - class PypikaJsonSet(PupikaFunction): - def __init__(self, field: F, expression: str, value: Any): - super().__init__("JSON_SET", field, expression, value) - - database_func = PypikaJsonSet - - json = await JSONFields.create(data={}) - self.assertEqual(json.data_default, {"a": 1}) - - json.data_default = JsonSet(F("data_default"), "$.a", 2) - await json.save() - - json_update = await JSONFields.get(pk=json.pk) - self.assertEqual(json_update.data_default, {"a": 2}) - - await JSONFields.filter(pk=json.pk).update( - data_default=JsonSet(F("data_default"), "$.a", 3) - ) - 
json_update = await JSONFields.get(pk=json.pk) - self.assertEqual(json_update.data_default, {"a": 3}) - - async def test_refresh_from_db(self): - int_field = await IntFields.create(intnum=1, intnum_null=2) - int_field_in_db = await IntFields.get(pk=int_field.pk) - int_field_in_db.intnum = F("intnum") + 1 - await int_field_in_db.save(update_fields=["intnum"]) - self.assertIsNot(int_field_in_db.intnum, 2) - self.assertIs(int_field_in_db.intnum_null, 2) - - await int_field_in_db.refresh_from_db(fields=["intnum"]) - self.assertIs(int_field_in_db.intnum, 2) - self.assertIs(int_field_in_db.intnum_null, 2) - - int_field_in_db.intnum = F("intnum") + 1 - await int_field_in_db.save() - self.assertIsNot(int_field_in_db.intnum, 3) - self.assertIs(int_field_in_db.intnum_null, 2) - - await int_field_in_db.refresh_from_db() - self.assertIs(int_field_in_db.intnum, 3) - self.assertIs(int_field_in_db.intnum_null, 2) - - @test.requireCapability(support_update_limit_order_by=True) - async def test_update_with_limit_ordering(self): - await Tournament.create(name="1") - t2 = await Tournament.create(name="1") - await Tournament.filter(name="1").limit(1).order_by("-id").update(name="2") - self.assertIs((await Tournament.get(pk=t2.pk)).name, "2") - self.assertEqual(await Tournament.filter(name="1").count(), 1) - - # tortoise-pypika does not translate ** to POWER in MSSQL - @test.requireCapability(dialect=NotEQ("mssql")) - async def test_update_with_case_when_and_f(self): - event1 = await IntFields.create(intnum=1) - event2 = await IntFields.create(intnum=2) - event3 = await IntFields.create(intnum=3) - await ( - IntFields.all() - .annotate( - intnum_updated=Case( - When( - Q(intnum=1), - then=F("intnum") + 1, - ), - When( - Q(intnum=2), - then=F("intnum") * 2, - ), - default=F("intnum") ** 3, - ) +@pytest.mark.asyncio +async def test_update(db): + await Tournament.create(name="1") + await Tournament.create(name="3") + rows_affected = await Tournament.all().update(name="2") + assert 
rows_affected == 2 + + tournament = await Tournament.first() + assert tournament.name == "2" + + +@pytest.mark.asyncio +async def test_bulk_update(db): + objs = [await Tournament.create(name="1"), await Tournament.create(name="2")] + objs[0].name = "0" + objs[1].name = "1" + rows_affected = await Tournament.bulk_update(objs, fields=["name"], batch_size=100) + assert rows_affected == 2 + assert (await Tournament.get(pk=objs[0].pk)).name == "0" + assert (await Tournament.get(pk=objs[1].pk)).name == "1" + + +@pytest.mark.asyncio +async def test_bulk_update_datetime(db): + objs = [ + await DatetimeFields.create(datetime=datetime(2021, 1, 1, tzinfo=pytz.utc)), + await DatetimeFields.create(datetime=datetime(2021, 1, 1, tzinfo=pytz.utc)), + ] + t0 = datetime(2021, 1, 2, tzinfo=pytz.utc) + t1 = datetime(2021, 1, 3, tzinfo=pytz.utc) + objs[0].datetime = t0 + objs[1].datetime = t1 + rows_affected = await DatetimeFields.bulk_update(objs, fields=["datetime"]) + assert rows_affected == 2 + assert (await DatetimeFields.get(pk=objs[0].pk)).datetime == t0 + assert (await DatetimeFields.get(pk=objs[1].pk)).datetime == t1 + + +@pytest.mark.asyncio +async def test_bulk_update_pk_non_id(db): + tournament = await Tournament.create(name="") + events = [ + await Event.create(name="1", tournament=tournament), + await Event.create(name="2", tournament=tournament), + ] + events[0].name = "3" + events[1].name = "4" + rows_affected = await Event.bulk_update(events, fields=["name"]) + assert rows_affected == 2 + assert (await Event.get(pk=events[0].pk)).name == events[0].name + assert (await Event.get(pk=events[1].pk)).name == events[1].name + + +@pytest.mark.asyncio +async def test_bulk_update_pk_uuid(db): + objs = [ + await UUIDFields.create(data=uuid.uuid4()), + await UUIDFields.create(data=uuid.uuid4()), + ] + objs[0].data = uuid.uuid4() + objs[1].data = uuid.uuid4() + rows_affected = await UUIDFields.bulk_update(objs, fields=["data"]) + assert rows_affected == 2 + assert (await 
UUIDFields.get(pk=objs[0].pk)).data == objs[0].data + assert (await UUIDFields.get(pk=objs[1].pk)).data == objs[1].data + + +@pytest.mark.asyncio +async def test_bulk_renamed_pk_source_field(db): + objs = [ + await SourceFieldPk.create(name="Model 1"), + await SourceFieldPk.create(name="Model 2"), + ] + objs[0].name = "Model 3" + objs[1].name = "Model 4" + rows_affected = await SourceFieldPk.bulk_update(objs, fields=["name"]) + assert rows_affected == 2 + assert (await SourceFieldPk.get(pk=objs[0].pk)).name == objs[0].name + assert (await SourceFieldPk.get(pk=objs[1].pk)).name == objs[1].name + + +@pytest.mark.asyncio +async def test_bulk_update_json_value(db): + objs = [ + await JSONFields.create(data={}), + await JSONFields.create(data={}), + ] + objs[0].data = [0] + objs[1].data = {"a": 1} + rows_affected = await JSONFields.bulk_update(objs, fields=["data"]) + assert rows_affected == 2 + assert (await JSONFields.get(pk=objs[0].pk)).data == objs[0].data + assert (await JSONFields.get(pk=objs[1].pk)).data == objs[1].data + + +@requireCapability(dialect=NotEQ("mssql")) +@pytest.mark.asyncio +async def test_bulk_update_smallint_none(db): + objs = [ + await SmallIntFields.create(smallintnum=1, smallintnum_null=1), + await SmallIntFields.create(smallintnum=2, smallintnum_null=2), + ] + objs[0].smallintnum_null = None + objs[1].smallintnum_null = None + rows_affected = await SmallIntFields.bulk_update(objs, fields=["smallintnum_null"]) + assert rows_affected == 2 + assert (await SmallIntFields.get(pk=objs[0].pk)).smallintnum_null is None + assert (await SmallIntFields.get(pk=objs[1].pk)).smallintnum_null is None + + +@pytest.mark.asyncio +async def test_bulk_update_custom_field(db): + objs = [ + await EnumFields.create(service=Service.python_programming, currency=Currency.EUR), + await EnumFields.create(service=Service.database_design, currency=Currency.USD), + ] + objs[0].currency = Currency.USD + objs[1].service = Service.system_administration + rows_affected = await 
EnumFields.bulk_update(objs, fields=["service", "currency"]) + assert rows_affected == 2 + assert (await EnumFields.get(pk=objs[0].pk)).currency == Currency.USD + assert (await EnumFields.get(pk=objs[1].pk)).service == Service.system_administration + + +@pytest.mark.asyncio +async def test_update_auto_now(db): + obj = await DefaultUpdate.create() + + updated_at = timezone.now() - timedelta(days=1) + await DefaultUpdate.filter(pk=obj.pk).update(updated_at=updated_at) + + obj1 = await DefaultUpdate.get(pk=obj.pk) + assert obj1.updated_at.date() == updated_at.date() + + +@pytest.mark.asyncio +async def test_update_relation(db): + tournament_first = await Tournament.create(name="1") + tournament_second = await Tournament.create(name="2") + + await Event.create(name="1", tournament=tournament_first) + await Event.all().update(tournament=tournament_second) + event = await Event.first() + assert event.tournament_id == tournament_second.id + + +@requireCapability(dialect=In("mysql", "sqlite")) +@pytest.mark.asyncio +async def test_update_with_custom_function(db): + class JsonSet(Function): + class PypikaJsonSet(PupikaFunction): + def __init__(self, field: F, expression: str, value: Any): + super().__init__("JSON_SET", field, expression, value) + + database_func = PypikaJsonSet + + json = await JSONFields.create(data={}) + assert json.data_default == {"a": 1} + + json.data_default = JsonSet(F("data_default"), "$.a", 2) + await json.save() + + json_update = await JSONFields.get(pk=json.pk) + assert json_update.data_default == {"a": 2} + + await JSONFields.filter(pk=json.pk).update(data_default=JsonSet(F("data_default"), "$.a", 3)) + json_update = await JSONFields.get(pk=json.pk) + assert json_update.data_default == {"a": 3} + + +@pytest.mark.asyncio +async def test_refresh_from_db(db): + int_field = await IntFields.create(intnum=1, intnum_null=2) + int_field_in_db = await IntFields.get(pk=int_field.pk) + int_field_in_db.intnum = F("intnum") + 1 + await 
int_field_in_db.save(update_fields=["intnum"]) + assert int_field_in_db.intnum != 2 + assert int_field_in_db.intnum_null == 2 + + await int_field_in_db.refresh_from_db(fields=["intnum"]) + assert int_field_in_db.intnum == 2 + assert int_field_in_db.intnum_null == 2 + + int_field_in_db.intnum = F("intnum") + 1 + await int_field_in_db.save() + assert int_field_in_db.intnum != 3 + assert int_field_in_db.intnum_null == 2 + + await int_field_in_db.refresh_from_db() + assert int_field_in_db.intnum == 3 + assert int_field_in_db.intnum_null == 2 + + +@requireCapability(support_update_limit_order_by=True) +@pytest.mark.asyncio +async def test_update_with_limit_ordering(db): + await Tournament.create(name="1") + t2 = await Tournament.create(name="1") + await Tournament.filter(name="1").limit(1).order_by("-id").update(name="2") + assert (await Tournament.get(pk=t2.pk)).name == "2" + assert await Tournament.filter(name="1").count() == 1 + + +# tortoise-pypika does not translate ** to POWER in MSSQL +@requireCapability(dialect=NotEQ("mssql")) +@pytest.mark.asyncio +async def test_update_with_case_when_and_f(db): + event1 = await IntFields.create(intnum=1) + event2 = await IntFields.create(intnum=2) + event3 = await IntFields.create(intnum=3) + await ( + IntFields.all() + .annotate( + intnum_updated=Case( + When( + Q(intnum=1), + then=F("intnum") + 1, + ), + When( + Q(intnum=2), + then=F("intnum") * 2, + ), + default=F("intnum") ** 3, ) - .update(intnum=F("intnum_updated")) ) - - for e in [event1, event2, event3]: - await e.refresh_from_db() - self.assertEqual(event1.intnum, 2) - self.assertEqual(event2.intnum, 4) - self.assertEqual(event3.intnum, 27) - - async def test_update_with_function_annotation(self): - tournament = await Tournament.create(name="aaa") - await ( - Tournament.filter(pk=tournament.pk) - .annotate( - upped_name=Upper(F("name")), - ) - .update(name=F("upped_name")) + .update(intnum=F("intnum_updated")) + ) + + for e in [event1, event2, event3]: + await 
e.refresh_from_db() + assert event1.intnum == 2 + assert event2.intnum == 4 + assert event3.intnum == 27 + + +@pytest.mark.asyncio +async def test_update_with_function_annotation(db): + tournament = await Tournament.create(name="aaa") + await ( + Tournament.filter(pk=tournament.pk) + .annotate( + upped_name=Upper(F("name")), ) - self.assertEqual((await Tournament.get(pk=tournament.pk)).name, "AAA") - - async def test_update_with_filter_subquery(self): - t1 = await Tournament.create(name="1") - r1 = await Reporter.create(name="1") - e1 = await Event.create(name="1", tournament=t1, reporter=r1) - - # NOTE: this is intentionally written with Subquery and known PKs to test - # whether subqueries are parameterized correctly. - await Event.filter( - tournament_id__in=Subquery(Tournament.filter(pk__in=[t1.pk]).values("id")), - reporter_id__in=Subquery(Reporter.filter(pk__in=[r1.pk]).values("id")), - ).update(token="hello_world") - - await e1.refresh_from_db(fields=["token"]) - self.assertEqual(e1.token, "hello_world") + .update(name=F("upped_name")) + ) + assert (await Tournament.get(pk=tournament.pk)).name == "AAA" + + +@pytest.mark.asyncio +async def test_update_with_filter_subquery(db): + t1 = await Tournament.create(name="1") + r1 = await Reporter.create(name="1") + e1 = await Event.create(name="1", tournament=t1, reporter=r1) + + # NOTE: this is intentionally written with Subquery and known PKs to test + # whether subqueries are parameterized correctly. 
+ await Event.filter( + tournament_id__in=Subquery(Tournament.filter(pk__in=[t1.pk]).values("id")), + reporter_id__in=Subquery(Reporter.filter(pk__in=[r1.pk]).values("id")), + ).update(token="hello_world") + + await e1.refresh_from_db(fields=["token"]) + assert e1.token == "hello_world" diff --git a/tests/test_validators.py b/tests/test_validators.py index 51703a71c..64596aa1f 100644 --- a/tests/test_validators.py +++ b/tests/test_validators.py @@ -1,82 +1,103 @@ from decimal import Decimal +import pytest + from tests.testmodels import ValidatorModel -from tortoise.contrib import test from tortoise.exceptions import ValidationError -class TestValues(test.TestCase): - async def test_validator_regex(self): - with self.assertRaises(ValidationError): - await ValidatorModel.create(regex="ccc") - await ValidatorModel.create(regex="abcd") - - async def test_validator_max_length(self): - with self.assertRaises(ValidationError): - await ValidatorModel.create(max_length="aaaaaa") - await ValidatorModel.create(max_length="aaaaa") - - async def test_validator_min_value(self): - # min value is 10 - with self.assertRaises(ValidationError): - await ValidatorModel.create(min_value=9) - await ValidatorModel.create(min_value=10) - - # min value is Decimal("1.0") - with self.assertRaises(ValidationError): - await ValidatorModel.create(min_value_decimal=Decimal("0.9")) - await ValidatorModel.create(min_value_decimal=Decimal("1.0")) - - async def test_validator_max_value(self): - # max value is 20 - with self.assertRaises(ValidationError): - await ValidatorModel.create(max_value=21) - await ValidatorModel.create(max_value=20) - - # max value is Decimal("2.0") - with self.assertRaises(ValidationError): - await ValidatorModel.create(max_value_decimal=Decimal("3.0")) - await ValidatorModel.create(max_value_decimal=Decimal("2.0")) - - async def test_validator_ipv4(self): - with self.assertRaises(ValidationError): - await ValidatorModel.create(ipv4="aaaaaa") - await 
ValidatorModel.create(ipv4="8.8.8.8") - - async def test_validator_ipv6(self): - with self.assertRaises(ValidationError): - await ValidatorModel.create(ipv6="aaaaaa") - await ValidatorModel.create(ipv6="::") - - async def test_validator_comma_separated_integer_list(self): - with self.assertRaises(ValidationError): - await ValidatorModel.create(comma_separated_integer_list="aaaaaa") - await ValidatorModel.create(comma_separated_integer_list="1,2,3") - - async def test__prevent_saving(self): - with self.assertRaises(ValidationError): - await ValidatorModel.create(min_value_decimal=Decimal("0.9")) - - self.assertEqual(await ValidatorModel.all().count(), 0) - - async def test_save(self): - with self.assertRaises(ValidationError): - record = ValidatorModel(min_value_decimal=Decimal("0.9")) - await record.save() - - record.min_value_decimal = Decimal("1.5") +@pytest.mark.asyncio +async def test_validator_regex(db): + with pytest.raises(ValidationError): + await ValidatorModel.create(regex="ccc") + await ValidatorModel.create(regex="abcd") + + +@pytest.mark.asyncio +async def test_validator_max_length(db): + with pytest.raises(ValidationError): + await ValidatorModel.create(max_length="aaaaaa") + await ValidatorModel.create(max_length="aaaaa") + + +@pytest.mark.asyncio +async def test_validator_min_value(db): + # min value is 10 + with pytest.raises(ValidationError): + await ValidatorModel.create(min_value=9) + await ValidatorModel.create(min_value=10) + + # min value is Decimal("1.0") + with pytest.raises(ValidationError): + await ValidatorModel.create(min_value_decimal=Decimal("0.9")) + await ValidatorModel.create(min_value_decimal=Decimal("1.0")) + + +@pytest.mark.asyncio +async def test_validator_max_value(db): + # max value is 20 + with pytest.raises(ValidationError): + await ValidatorModel.create(max_value=21) + await ValidatorModel.create(max_value=20) + + # max value is Decimal("2.0") + with pytest.raises(ValidationError): + await 
ValidatorModel.create(max_value_decimal=Decimal("3.0")) + await ValidatorModel.create(max_value_decimal=Decimal("2.0")) + + +@pytest.mark.asyncio +async def test_validator_ipv4(db): + with pytest.raises(ValidationError): + await ValidatorModel.create(ipv4="aaaaaa") + await ValidatorModel.create(ipv4="8.8.8.8") + + +@pytest.mark.asyncio +async def test_validator_ipv6(db): + with pytest.raises(ValidationError): + await ValidatorModel.create(ipv6="aaaaaa") + await ValidatorModel.create(ipv6="::") + + +@pytest.mark.asyncio +async def test_validator_comma_separated_integer_list(db): + with pytest.raises(ValidationError): + await ValidatorModel.create(comma_separated_integer_list="aaaaaa") + await ValidatorModel.create(comma_separated_integer_list="1,2,3") + + +@pytest.mark.asyncio +async def test_prevent_saving(db): + with pytest.raises(ValidationError): + await ValidatorModel.create(min_value_decimal=Decimal("0.9")) + + assert await ValidatorModel.all().count() == 0 + + +@pytest.mark.asyncio +async def test_save(db): + with pytest.raises(ValidationError): + record = ValidatorModel(min_value_decimal=Decimal("0.9")) await record.save() - async def test_save_with_update_fields(self): - record = await ValidatorModel.create(min_value_decimal=Decimal("2")) + record.min_value_decimal = Decimal("1.5") + await record.save() + - record.min_value_decimal = Decimal("0.9") - with self.assertRaises(ValidationError): - await record.save(update_fields=["min_value_decimal"]) +@pytest.mark.asyncio +async def test_save_with_update_fields(db): + record = await ValidatorModel.create(min_value_decimal=Decimal("2")) - async def test_update(self): - record = await ValidatorModel.create(min_value_decimal=Decimal("2")) + record.min_value_decimal = Decimal("0.9") + with pytest.raises(ValidationError): + await record.save(update_fields=["min_value_decimal"]) - record.min_value_decimal = Decimal("0.9") - with self.assertRaises(ValidationError): - await record.save() + +@pytest.mark.asyncio +async 
def test_update(db): + record = await ValidatorModel.create(min_value_decimal=Decimal("2")) + + record.min_value_decimal = Decimal("0.9") + with pytest.raises(ValidationError): + await record.save() diff --git a/tests/test_values.py b/tests/test_values.py index b681b318b..6a5ffeccf 100644 --- a/tests/test_values.py +++ b/tests/test_values.py @@ -1,3 +1,4 @@ +import pytest from pypika_tortoise import CustomFunction from tests.testmodels import Event, Team, Tournament @@ -8,247 +9,282 @@ from tortoise.functions import Length, Trim -class TestValues(test.TestCase): - async def test_values_related_fk(self): - tournament = await Tournament.create(name="New Tournament") - await Event.create(name="Test", tournament_id=tournament.id) +@pytest.mark.asyncio +async def test_values_related_fk(db): + tournament = await Tournament.create(name="New Tournament") + await Event.create(name="Test", tournament_id=tournament.id) - event2 = await Event.filter(name="Test").values("name", "tournament__name") - self.assertEqual(event2[0], {"name": "Test", "tournament__name": "New Tournament"}) + event2 = await Event.filter(name="Test").values("name", "tournament__name") + assert event2[0] == {"name": "Test", "tournament__name": "New Tournament"} - async def test_values_list_related_fk(self): - tournament = await Tournament.create(name="New Tournament") - await Event.create(name="Test", tournament_id=tournament.id) - event2 = await Event.filter(name="Test").values_list("name", "tournament__name") - self.assertEqual(event2[0], ("Test", "New Tournament")) +@pytest.mark.asyncio +async def test_values_list_related_fk(db): + tournament = await Tournament.create(name="New Tournament") + await Event.create(name="Test", tournament_id=tournament.id) - async def test_values_related_rfk(self): - tournament = await Tournament.create(name="New Tournament") - await Event.create(name="Test", tournament_id=tournament.id) + event2 = await Event.filter(name="Test").values_list("name", "tournament__name") + 
assert event2[0] == ("Test", "New Tournament") - tournament2 = await Tournament.filter(name="New Tournament").values("name", "events__name") - self.assertEqual(tournament2[0], {"name": "New Tournament", "events__name": "Test"}) - async def test_values_related_rfk_reuse_query(self): - tournament = await Tournament.create(name="New Tournament") - await Event.create(name="Test", tournament_id=tournament.id) +@pytest.mark.asyncio +async def test_values_related_rfk(db): + tournament = await Tournament.create(name="New Tournament") + await Event.create(name="Test", tournament_id=tournament.id) - query = Tournament.filter(name="New Tournament").values("name", "events__name") - tournament2 = await query - self.assertEqual(tournament2[0], {"name": "New Tournament", "events__name": "Test"}) + tournament2 = await Tournament.filter(name="New Tournament").values("name", "events__name") + assert tournament2[0] == {"name": "New Tournament", "events__name": "Test"} - tournament2 = await query - self.assertEqual(tournament2[0], {"name": "New Tournament", "events__name": "Test"}) - async def test_values_list_related_rfk(self): - tournament = await Tournament.create(name="New Tournament") - await Event.create(name="Test", tournament_id=tournament.id) +@pytest.mark.asyncio +async def test_values_related_rfk_reuse_query(db): + tournament = await Tournament.create(name="New Tournament") + await Event.create(name="Test", tournament_id=tournament.id) - tournament2 = await Tournament.filter(name="New Tournament").values_list( - "name", "events__name" - ) - self.assertEqual(tournament2[0], ("New Tournament", "Test")) + query = Tournament.filter(name="New Tournament").values("name", "events__name") + tournament2 = await query + assert tournament2[0] == {"name": "New Tournament", "events__name": "Test"} - async def test_values_list_related_rfk_reuse_query(self): - tournament = await Tournament.create(name="New Tournament") - await Event.create(name="Test", tournament_id=tournament.id) + 
tournament2 = await query + assert tournament2[0] == {"name": "New Tournament", "events__name": "Test"} - query = Tournament.filter(name="New Tournament").values_list("name", "events__name") - tournament2 = await query - self.assertEqual(tournament2[0], ("New Tournament", "Test")) - tournament2 = await query - self.assertEqual(tournament2[0], ("New Tournament", "Test")) +@pytest.mark.asyncio +async def test_values_list_related_rfk(db): + tournament = await Tournament.create(name="New Tournament") + await Event.create(name="Test", tournament_id=tournament.id) - async def test_values_related_m2m(self): - tournament = await Tournament.create(name="New Tournament") - event = await Event.create(name="Test", tournament_id=tournament.id) - team = await Team.create(name="Some Team") - await event.participants.add(team) + tournament2 = await Tournament.filter(name="New Tournament").values_list("name", "events__name") + assert tournament2[0] == ("New Tournament", "Test") - tournament2 = await Event.filter(name="Test").values("name", "participants__name") - self.assertEqual(tournament2[0], {"name": "Test", "participants__name": "Some Team"}) - async def test_values_list_related_m2m(self): - tournament = await Tournament.create(name="New Tournament") - event = await Event.create(name="Test", tournament_id=tournament.id) - team = await Team.create(name="Some Team") - await event.participants.add(team) +@pytest.mark.asyncio +async def test_values_list_related_rfk_reuse_query(db): + tournament = await Tournament.create(name="New Tournament") + await Event.create(name="Test", tournament_id=tournament.id) - tournament2 = await Event.filter(name="Test").values_list("name", "participants__name") - self.assertEqual(tournament2[0], ("Test", "Some Team")) + query = Tournament.filter(name="New Tournament").values_list("name", "events__name") + tournament2 = await query + assert tournament2[0] == ("New Tournament", "Test") - async def test_values_related_fk_itself(self): - tournament = 
await Tournament.create(name="New Tournament") - await Event.create(name="Test", tournament_id=tournament.id) + tournament2 = await query + assert tournament2[0] == ("New Tournament", "Test") - with self.assertRaisesRegex(ValueError, 'Selecting relation "tournament" is not possible'): - await Event.filter(name="Test").values("name", "tournament") - async def test_values_list_related_fk_itself(self): - tournament = await Tournament.create(name="New Tournament") - await Event.create(name="Test", tournament_id=tournament.id) +@pytest.mark.asyncio +async def test_values_related_m2m(db): + tournament = await Tournament.create(name="New Tournament") + event = await Event.create(name="Test", tournament_id=tournament.id) + team = await Team.create(name="Some Team") + await event.participants.add(team) - with self.assertRaisesRegex(ValueError, 'Selecting relation "tournament" is not possible'): - await Event.filter(name="Test").values_list("name", "tournament") + tournament2 = await Event.filter(name="Test").values("name", "participants__name") + assert tournament2[0] == {"name": "Test", "participants__name": "Some Team"} - async def test_values_related_rfk_itself(self): - tournament = await Tournament.create(name="New Tournament") - await Event.create(name="Test", tournament_id=tournament.id) - with self.assertRaisesRegex(ValueError, 'Selecting relation "events" is not possible'): - await Tournament.filter(name="New Tournament").values("name", "events") +@pytest.mark.asyncio +async def test_values_list_related_m2m(db): + tournament = await Tournament.create(name="New Tournament") + event = await Event.create(name="Test", tournament_id=tournament.id) + team = await Team.create(name="Some Team") + await event.participants.add(team) - async def test_values_list_related_rfk_itself(self): - tournament = await Tournament.create(name="New Tournament") - await Event.create(name="Test", tournament_id=tournament.id) + tournament2 = await Event.filter(name="Test").values_list("name", 
"participants__name") + assert tournament2[0] == ("Test", "Some Team") - with self.assertRaisesRegex(ValueError, 'Selecting relation "events" is not possible'): - await Tournament.filter(name="New Tournament").values_list("name", "events") - async def test_values_related_m2m_itself(self): - tournament = await Tournament.create(name="New Tournament") - event = await Event.create(name="Test", tournament_id=tournament.id) - team = await Team.create(name="Some Team") - await event.participants.add(team) +@pytest.mark.asyncio +async def test_values_related_fk_itself(db): + tournament = await Tournament.create(name="New Tournament") + await Event.create(name="Test", tournament_id=tournament.id) - with self.assertRaisesRegex( - ValueError, 'Selecting relation "participants" is not possible' - ): - await Event.filter(name="Test").values("name", "participants") + with pytest.raises(ValueError, match='Selecting relation "tournament" is not possible'): + await Event.filter(name="Test").values("name", "tournament") - async def test_values_list_related_m2m_itself(self): - tournament = await Tournament.create(name="New Tournament") - event = await Event.create(name="Test", tournament_id=tournament.id) - team = await Team.create(name="Some Team") - await event.participants.add(team) - with self.assertRaisesRegex( - ValueError, 'Selecting relation "participants" is not possible' - ): - await Event.filter(name="Test").values_list("name", "participants") +@pytest.mark.asyncio +async def test_values_list_related_fk_itself(db): + tournament = await Tournament.create(name="New Tournament") + await Event.create(name="Test", tournament_id=tournament.id) - async def test_values_bad_key(self): - tournament = await Tournament.create(name="New Tournament") - await Event.create(name="Test", tournament_id=tournament.id) + with pytest.raises(ValueError, match='Selecting relation "tournament" is not possible'): + await Event.filter(name="Test").values_list("name", "tournament") - with 
self.assertRaisesRegex(FieldError, 'Unknown field "neem" for model "Event"'): - await Event.filter(name="Test").values("name", "neem") - async def test_values_list_bad_key(self): - tournament = await Tournament.create(name="New Tournament") - await Event.create(name="Test", tournament_id=tournament.id) +@pytest.mark.asyncio +async def test_values_related_rfk_itself(db): + tournament = await Tournament.create(name="New Tournament") + await Event.create(name="Test", tournament_id=tournament.id) - with self.assertRaisesRegex(FieldError, 'Unknown field "neem" for model "Event"'): - await Event.filter(name="Test").values_list("name", "neem") + with pytest.raises(ValueError, match='Selecting relation "events" is not possible'): + await Tournament.filter(name="New Tournament").values("name", "events") - async def test_values_related_bad_key(self): - tournament = await Tournament.create(name="New Tournament") - await Event.create(name="Test", tournament_id=tournament.id) - with self.assertRaisesRegex(FieldError, 'Unknown field "neem" for model "Tournament"'): - await Event.filter(name="Test").values("name", "tournament__neem") +@pytest.mark.asyncio +async def test_values_list_related_rfk_itself(db): + tournament = await Tournament.create(name="New Tournament") + await Event.create(name="Test", tournament_id=tournament.id) - async def test_values_list_related_bad_key(self): - tournament = await Tournament.create(name="New Tournament") - await Event.create(name="Test", tournament_id=tournament.id) + with pytest.raises(ValueError, match='Selecting relation "events" is not possible'): + await Tournament.filter(name="New Tournament").values_list("name", "events") - with self.assertRaisesRegex(FieldError, 'Unknown field "neem" for model "Tournament"'): - await Event.filter(name="Test").values_list("name", "tournament__neem") - @test.requireCapability(dialect="!mssql") - async def test_values_list_annotations_length(self): - await Tournament.create(name="Championship") - await 
Tournament.create(name="Super Bowl") +@pytest.mark.asyncio +async def test_values_related_m2m_itself(db): + tournament = await Tournament.create(name="New Tournament") + event = await Event.create(name="Test", tournament_id=tournament.id) + team = await Team.create(name="Some Team") + await event.participants.add(team) - tournaments = await Tournament.annotate(name_length=Length("name")).values_list( - "name", "name_length" - ) - self.assertListSortEqual(tournaments, [("Championship", 12), ("Super Bowl", 10)]) + with pytest.raises(ValueError, match='Selecting relation "participants" is not possible'): + await Event.filter(name="Test").values("name", "participants") - @test.requireCapability(dialect=NotEQ("mssql")) - async def test_values_annotations_length(self): - await Tournament.create(name="Championship") - await Tournament.create(name="Super Bowl") - tournaments = await Tournament.annotate(name_slength=Length("name")).values( - "name", "name_slength" - ) - self.assertListSortEqual( - tournaments, - [ - {"name": "Championship", "name_slength": 12}, - {"name": "Super Bowl", "name_slength": 10}, - ], - sorted_key="name", - ) +@pytest.mark.asyncio +async def test_values_list_related_m2m_itself(db): + tournament = await Tournament.create(name="New Tournament") + event = await Event.create(name="Test", tournament_id=tournament.id) + team = await Team.create(name="Some Team") + await event.participants.add(team) - async def test_values_list_annotations_trim(self): - await Tournament.create(name=" x") - await Tournament.create(name=" y ") + with pytest.raises(ValueError, match='Selecting relation "participants" is not possible'): + await Event.filter(name="Test").values_list("name", "participants") - tournaments = await Tournament.annotate(name_trim=Trim("name")).values_list( - "name", "name_trim" - ) - self.assertListSortEqual(tournaments, [(" x", "x"), (" y ", "y")]) - async def test_values_annotations_trim(self): - await Tournament.create(name=" x") - await 
Tournament.create(name=" y ") +@pytest.mark.asyncio +async def test_values_bad_key(db): + tournament = await Tournament.create(name="New Tournament") + await Event.create(name="Test", tournament_id=tournament.id) - tournaments = await Tournament.annotate(name_trim=Trim("name")).values("name", "name_trim") - self.assertListSortEqual( - tournaments, - [{"name": " x", "name_trim": "x"}, {"name": " y ", "name_trim": "y"}], - sorted_key="name", - ) + with pytest.raises(FieldError, match='Unknown field "neem" for model "Event"'): + await Event.filter(name="Test").values("name", "neem") - @test.requireCapability(dialect=In("sqlite")) - async def test_values_with_custom_function(self): - class TruncMonth(Function): - database_func = CustomFunction("DATE_FORMAT", ["name", "dt_format"]) - sql = Tournament.all().annotate(date=TruncMonth("created", "%Y-%m-%d")).values("date").sql() - self.assertEqual( - sql, - 'SELECT DATE_FORMAT("created",?) "date" FROM "tournament"', - ) +@pytest.mark.asyncio +async def test_values_list_bad_key(db): + tournament = await Tournament.create(name="New Tournament") + await Event.create(name="Test", tournament_id=tournament.id) + + with pytest.raises(FieldError, match='Unknown field "neem" for model "Event"'): + await Event.filter(name="Test").values_list("name", "neem") + + +@pytest.mark.asyncio +async def test_values_related_bad_key(db): + tournament = await Tournament.create(name="New Tournament") + await Event.create(name="Test", tournament_id=tournament.id) + + with pytest.raises(FieldError, match='Unknown field "neem" for model "Tournament"'): + await Event.filter(name="Test").values("name", "tournament__neem") + + +@pytest.mark.asyncio +async def test_values_list_related_bad_key(db): + tournament = await Tournament.create(name="New Tournament") + await Event.create(name="Test", tournament_id=tournament.id) + + with pytest.raises(FieldError, match='Unknown field "neem" for model "Tournament"'): + await 
Event.filter(name="Test").values_list("name", "tournament__neem") + + +@test.requireCapability(dialect="!mssql") +@pytest.mark.asyncio +async def test_values_list_annotations_length(db): + await Tournament.create(name="Championship") + await Tournament.create(name="Super Bowl") + + tournaments = await Tournament.annotate(name_length=Length("name")).values_list( + "name", "name_length" + ) + assert sorted(tournaments) == sorted([("Championship", 12), ("Super Bowl", 10)]) + + +@test.requireCapability(dialect=NotEQ("mssql")) +@pytest.mark.asyncio +async def test_values_annotations_length(db): + await Tournament.create(name="Championship") + await Tournament.create(name="Super Bowl") + + tournaments = await Tournament.annotate(name_slength=Length("name")).values( + "name", "name_slength" + ) + assert sorted(tournaments, key=lambda x: x["name"]) == sorted( + [ + {"name": "Championship", "name_slength": 12}, + {"name": "Super Bowl", "name_slength": 10}, + ], + key=lambda x: x["name"], + ) + + +@pytest.mark.asyncio +async def test_values_list_annotations_trim(db): + await Tournament.create(name=" x") + await Tournament.create(name=" y ") + + tournaments = await Tournament.annotate(name_trim=Trim("name")).values_list("name", "name_trim") + assert sorted(tournaments) == sorted([(" x", "x"), (" y ", "y")]) + + +@pytest.mark.asyncio +async def test_values_annotations_trim(db): + await Tournament.create(name=" x") + await Tournament.create(name=" y ") + + tournaments = await Tournament.annotate(name_trim=Trim("name")).values("name", "name_trim") + assert sorted(tournaments, key=lambda x: x["name"]) == sorted( + [{"name": " x", "name_trim": "x"}, {"name": " y ", "name_trim": "y"}], + key=lambda x: x["name"], + ) + + +@test.requireCapability(dialect=In("sqlite")) +@pytest.mark.asyncio +async def test_values_with_custom_function(db): + class TruncMonth(Function): + database_func = CustomFunction("DATE_FORMAT", ["name", "dt_format"]) + + sql = 
Tournament.all().annotate(date=TruncMonth("created", "%Y-%m-%d")).values("date").sql() + assert sql == 'SELECT DATE_FORMAT("created",?) "date" FROM "tournament"' + + +@pytest.mark.asyncio +async def test_order_by_annotation_not_in_values(db): + await Tournament.create(name="2") + await Tournament.create(name="3") + await Tournament.create(name="1") - async def test_order_by_annotation_not_in_values(self): - await Tournament.create(name="2") - await Tournament.create(name="3") - await Tournament.create(name="1") - - tournaments = ( - await Tournament.annotate( - name_orderable=Case( - When(Q(name="1"), then="a"), - When(Q(name="2"), then="b"), - When(Q(name="3"), then="c"), - default="z", - ) + tournaments = ( + await Tournament.annotate( + name_orderable=Case( + When(Q(name="1"), then="a"), + When(Q(name="2"), then="b"), + When(Q(name="3"), then="c"), + default="z", ) - .order_by("name_orderable") - .values("name") ) - self.assertEqual([t["name"] for t in tournaments], ["1", "2", "3"]) - - async def test_order_by_annotation_not_in_values_list(self): - await Tournament.create(name="2") - await Tournament.create(name="3") - await Tournament.create(name="1") - - tournaments = ( - await Tournament.annotate( - name_orderable=Case( - When(Q(name="1"), then="a"), - When(Q(name="2"), then="b"), - When(Q(name="3"), then="c"), - default="z", - ) + .order_by("name_orderable") + .values("name") + ) + assert [t["name"] for t in tournaments] == ["1", "2", "3"] + + +@pytest.mark.asyncio +async def test_order_by_annotation_not_in_values_list(db): + await Tournament.create(name="2") + await Tournament.create(name="3") + await Tournament.create(name="1") + + tournaments = ( + await Tournament.annotate( + name_orderable=Case( + When(Q(name="1"), then="a"), + When(Q(name="2"), then="b"), + When(Q(name="3"), then="c"), + default="z", ) - .order_by("name_orderable") - .values_list("name") ) - self.assertEqual(tournaments, [("1",), ("2",), ("3",)]) + .order_by("name_orderable") + 
.values_list("name") + ) + assert tournaments == [("1",), ("2",), ("3",)] diff --git a/tests/utils/test_describe_model.py b/tests/utils/test_describe_model.py index 1dfc95975..97ce90a9f 100644 --- a/tests/utils/test_describe_model.py +++ b/tests/utils/test_describe_model.py @@ -3,6 +3,8 @@ import json import uuid +import pytest + from tests.testmodels import ( Event, JSONFields, @@ -20,7 +22,6 @@ json_pydantic_default, ) from tortoise import Tortoise, fields -from tortoise.contrib import test from tortoise.fields.relational import ( BackwardFKRelation, ForeignKeyFieldInstance, @@ -36,1601 +37,1572 @@ def union_annotation(x: str, y: str) -> str: UNION_DICT_LIST = union_annotation("dict", "list") -class TestDescribeModels(test.TestCase): - def test_describe_models_all_serializable(self): - val = Tortoise.describe_models() +# Tests that require database connection (TestDescribeModels) +@pytest.mark.asyncio +async def test_describe_models_all_serializable(db): + val = Tortoise.describe_models() + json.dumps(val) + assert "models.SourceFields" in val.keys() + assert "models.Event" in val.keys() + + +@pytest.mark.asyncio +async def test_describe_models_all_not_serializable(db): + val = Tortoise.describe_models(serializable=False) + with pytest.raises(TypeError, match="not JSON serializable"): json.dumps(val) - self.assertIn("models.SourceFields", val.keys()) - self.assertIn("models.Event", val.keys()) + assert "models.SourceFields" in val.keys() + assert "models.Event" in val.keys() + + +# Tests that don't require database connection (TestDescribeModel - SimpleTestCase equivalent) +def test_describe_field_noninit_ser(): + field = fields.IntField(primary_key=True) + assert field.describe(serializable=True) == { + "name": "", + "field_type": "IntField", + "db_column": "", + "db_field_types": {"": "INT"}, + "python_type": "int", + "generated": True, + "nullable": False, + "unique": True, + "indexed": True, + "default": None, + "description": None, + "docstring": None, + 
"constraints": {"ge": -2147483648, "le": 2147483647}, + } + + +def test_describe_field_noninit(): + field = fields.IntField(primary_key=True) + assert field.describe(serializable=False) == { + "name": "", + "field_type": fields.IntField, + "db_column": "", + "db_field_types": {"": "INT"}, + "python_type": int, + "generated": True, + "nullable": False, + "unique": True, + "indexed": True, + "default": None, + "description": None, + "docstring": None, + "constraints": {"ge": -2147483648, "le": 2147483647}, + } + - def test_describe_models_all_not_serializable(self): - val = Tortoise.describe_models(serializable=False) - with self.assertRaisesRegex(TypeError, "not JSON serializable"): - json.dumps(val) - self.assertIn("models.SourceFields", val.keys()) - self.assertIn("models.Event", val.keys()) +def test_describe_relfield_noninit_ser(): + field = fields.ForeignKeyField("a.b") + assert field.describe(serializable=True) == { + "name": "", + "field_type": "ForeignKeyFieldInstance", + "python_type": "None", + "generated": False, + "nullable": False, + "on_delete": "CASCADE", + "unique": False, + "indexed": False, + "default": None, + "description": None, + "docstring": None, + "constraints": {}, + "db_constraint": True, + "raw_field": None, + } -class TestDescribeModel(test.SimpleTestCase): - maxDiff = None +def test_describe_relfield_noninit(): + field = fields.ForeignKeyField("a.b") + assert field.describe(serializable=False) == { + "name": "", + "field_type": ForeignKeyFieldInstance, + "python_type": None, + "generated": False, + "nullable": False, + "on_delete": "CASCADE", + "unique": False, + "indexed": False, + "default": None, + "description": None, + "docstring": None, + "constraints": {}, + "raw_field": None, + "db_constraint": True, + } - def test_describe_field_noninit_ser(self): - field = fields.IntField(primary_key=True) - self.assertEqual( - field.describe(serializable=True), + +def test_describe_models_some(): + val = Tortoise.describe_models([Event, 
Tournament, Reporter, Team]) + assert {"models.Event", "models.Tournament", "models.Reporter", "models.Team"} == set( + val.keys() + ) + + +def test_describe_model_straight(): + val = StraightFields.describe() + assert val == { + "name": "models.StraightFields", + "app": "models", + "table": "straightfields", + "abstract": False, + "description": "Straight auto-mapped fields", + "docstring": None, + "unique_together": [["chars", "blip"]], + "indexes": [], + "pk_field": { + "name": "eyedee", + "field_type": "IntField", + "db_column": "eyedee", + "db_field_types": {"": "INT"}, + "python_type": "int", + "generated": True, + "nullable": False, + "unique": True, + "indexed": True, + "default": None, + "description": "Da PK", + "docstring": None, + "constraints": {"ge": -2147483648, "le": 2147483647}, + }, + "data_fields": [ + { + "name": "chars", + "field_type": "CharField", + "db_column": "chars", + "db_field_types": { + "": "VARCHAR(50)", + "oracle": "NVARCHAR2(50)", + }, + "python_type": "str", + "generated": False, + "nullable": False, + "unique": False, + "indexed": True, + "default": None, + "description": "Some chars", + "docstring": None, + "constraints": {"max_length": 50}, + }, + { + "name": "blip", + "field_type": "CharField", + "db_column": "blip", + "db_field_types": { + "": "VARCHAR(50)", + "oracle": "NVARCHAR2(50)", + }, + "python_type": "str", + "generated": False, + "nullable": False, + "unique": False, + "indexed": False, + "default": "BLIP", + "description": None, + "docstring": None, + "constraints": {"max_length": 50}, + }, + { + "name": "nullable", + "field_type": "CharField", + "db_column": "nullable", + "db_field_types": { + "": "VARCHAR(50)", + "oracle": "NVARCHAR2(50)", + }, + "python_type": "str", + "generated": False, + "nullable": True, + "unique": False, + "indexed": False, + "default": None, + "description": None, + "docstring": None, + "constraints": {"max_length": 50}, + }, { - "name": "", + "name": "fk_id", "field_type": "IntField", - 
"db_column": "", + "db_column": "fk_id", + "db_field_types": {"": "INT"}, + "python_type": "int", + "generated": False, + "nullable": True, + "unique": False, + "indexed": False, + "default": None, + "description": "Tree!", + "docstring": None, + "constraints": {"ge": -2147483648, "le": 2147483647}, + }, + { + "db_column": "o2o_id", "db_field_types": {"": "INT"}, + "default": None, + "description": "Line", + "docstring": None, + "field_type": "IntField", + "generated": False, + "indexed": True, + "name": "o2o_id", + "nullable": True, "python_type": "int", - "generated": True, + "unique": True, + "constraints": {"ge": -2147483648, "le": 2147483647}, + }, + ], + "fk_fields": [ + { + "name": "fk", + "field_type": "ForeignKeyFieldInstance", + "raw_field": "fk_id", + "python_type": "models.StraightFields", + "generated": False, + "nullable": True, + "unique": False, + "on_delete": "NO ACTION", + "indexed": False, + "default": None, + "description": "Tree!", + "docstring": None, + "db_constraint": True, + "constraints": {}, + } + ], + "backward_fk_fields": [ + { + "name": "fkrev", + "field_type": "BackwardFKRelation", + "python_type": "models.StraightFields", + "generated": False, + "nullable": True, + "unique": False, + "indexed": False, + "default": None, + "description": "Tree!", + "docstring": None, + "constraints": {}, + "db_constraint": True, + } + ], + "o2o_fields": [ + { + "default": None, + "description": "Line", + "docstring": None, + "field_type": "OneToOneFieldInstance", + "generated": False, + "indexed": True, + "name": "o2o", + "nullable": True, + "on_delete": "NO ACTION", + "python_type": "models.StraightFields", + "raw_field": "o2o_id", + "unique": True, + "constraints": {}, + "db_constraint": True, + } + ], + "backward_o2o_fields": [ + { + "default": None, + "description": "Line", + "docstring": None, + "field_type": "BackwardOneToOneRelation", + "generated": False, + "indexed": False, + "name": "o2o_rev", + "nullable": True, + "python_type": 
"models.StraightFields", + "unique": False, + "constraints": {}, + "db_constraint": True, + } + ], + "m2m_fields": [ + { + "name": "rel_to", + "field_type": "ManyToManyFieldInstance", + "python_type": "models.StraightFields", + "generated": False, + "nullable": False, + "unique": True, + "indexed": True, + "default": None, + "description": "M2M to myself", + "docstring": None, + "constraints": {}, + "model_name": "models.StraightFields", + "related_name": "rel_from", + "forward_key": "straightfields_id", + "backward_key": "straightfields_rel_id", + "through": "straightfields_straightfields", + "on_delete": "NO ACTION", + "_generated": False, + "db_constraint": True, + }, + { + "name": "rel_from", + "field_type": "ManyToManyFieldInstance", + "python_type": "models.StraightFields", + "generated": False, "nullable": False, "unique": True, "indexed": True, "default": None, + "description": "M2M to myself", + "docstring": None, + "constraints": {}, + "model_name": "models.StraightFields", + "related_name": "rel_to", + "forward_key": "straightfields_rel_id", + "backward_key": "straightfields_id", + "through": "straightfields_straightfields", + "on_delete": "CASCADE", + "db_constraint": True, + "_generated": True, + }, + ], + } + + +def test_describe_model_straight_native(): + val = StraightFields.describe(serializable=False) + assert val == { + "name": "models.StraightFields", + "app": "models", + "table": "straightfields", + "abstract": False, + "description": "Straight auto-mapped fields", + "docstring": None, + "unique_together": [["chars", "blip"]], + "indexes": [], + "pk_field": { + "name": "eyedee", + "field_type": fields.IntField, + "db_column": "eyedee", + "db_field_types": {"": "INT"}, + "python_type": int, + "generated": True, + "nullable": False, + "unique": True, + "indexed": True, + "default": None, + "description": "Da PK", + "docstring": None, + "constraints": {"ge": -2147483648, "le": 2147483647}, + }, + "data_fields": [ + { + "name": "chars", + 
"field_type": fields.CharField, + "db_column": "chars", + "db_field_types": { + "": "VARCHAR(50)", + "oracle": "NVARCHAR2(50)", + }, + "python_type": str, + "generated": False, + "nullable": False, + "unique": False, + "indexed": True, + "default": None, + "description": "Some chars", + "docstring": None, + "constraints": {"max_length": 50}, + }, + { + "name": "blip", + "field_type": fields.CharField, + "db_column": "blip", + "db_field_types": { + "": "VARCHAR(50)", + "oracle": "NVARCHAR2(50)", + }, + "python_type": str, + "generated": False, + "nullable": False, + "unique": False, + "indexed": False, + "default": "BLIP", + "description": None, + "docstring": None, + "constraints": {"max_length": 50}, + }, + { + "name": "nullable", + "field_type": fields.CharField, + "db_column": "nullable", + "db_field_types": { + "": "VARCHAR(50)", + "oracle": "NVARCHAR2(50)", + }, + "python_type": str, + "generated": False, + "nullable": True, + "unique": False, + "indexed": False, + "default": None, "description": None, "docstring": None, + "constraints": {"max_length": 50}, + }, + { + "name": "fk_id", + "field_type": fields.IntField, + "db_column": "fk_id", + "db_field_types": {"": "INT"}, + "python_type": int, + "generated": False, + "nullable": True, + "unique": False, + "indexed": False, + "default": None, + "description": "Tree!", + "docstring": None, "constraints": {"ge": -2147483648, "le": 2147483647}, }, - ) - - def test_describe_field_noninit(self): - field = fields.IntField(primary_key=True) - self.assertEqual( - field.describe(serializable=False), { - "name": "", + "name": "o2o_id", "field_type": fields.IntField, - "db_column": "", + "db_column": "o2o_id", "db_field_types": {"": "INT"}, "python_type": int, - "generated": True, - "nullable": False, + "generated": False, + "nullable": True, "unique": True, "indexed": True, "default": None, - "description": None, + "description": "Line", "docstring": None, "constraints": {"ge": -2147483648, "le": 2147483647}, }, - ) - - 
def test_describe_relfield_noninit_ser(self): - field = fields.ForeignKeyField("a.b") - self.assertEqual( - field.describe(serializable=True), + ], + "fk_fields": [ { - "name": "", - "field_type": "ForeignKeyFieldInstance", - "python_type": "None", + "name": "fk", + "field_type": ForeignKeyFieldInstance, + "raw_field": "fk_id", + "python_type": StraightFields, "generated": False, - "nullable": False, - "on_delete": "CASCADE", + "nullable": True, + "on_delete": "NO ACTION", "unique": False, "indexed": False, "default": None, - "description": None, + "description": "Tree!", "docstring": None, "constraints": {}, "db_constraint": True, - "raw_field": None, - }, - ) - - def test_describe_relfield_noninit(self): - field = fields.ForeignKeyField("a.b") - self.assertEqual( - field.describe(serializable=False), + } + ], + "backward_fk_fields": [ { - "name": "", - "field_type": ForeignKeyFieldInstance, - "python_type": None, + "name": "fkrev", + "field_type": BackwardFKRelation, + "python_type": StraightFields, "generated": False, - "nullable": False, - "on_delete": "CASCADE", + "nullable": True, "unique": False, "indexed": False, "default": None, - "description": None, + "description": "Tree!", "docstring": None, "constraints": {}, - "raw_field": None, "db_constraint": True, - }, - ) - - def test_describe_models_some(self): - val = Tortoise.describe_models([Event, Tournament, Reporter, Team]) - self.assertEqual( - {"models.Event", "models.Tournament", "models.Reporter", "models.Team"}, set(val.keys()) - ) - - def test_describe_model_straight(self): - val = StraightFields.describe() - self.assertEqual( - val, + } + ], + "o2o_fields": [ { - "name": "models.StraightFields", - "app": "models", - "table": "straightfields", - "abstract": False, - "description": "Straight auto-mapped fields", + "default": None, + "description": "Line", "docstring": None, - "unique_together": [["chars", "blip"]], - "indexes": [], - "pk_field": { - "name": "eyedee", - "field_type": "IntField", - 
"db_column": "eyedee", - "db_field_types": {"": "INT"}, - "python_type": "int", - "generated": True, - "nullable": False, - "unique": True, - "indexed": True, - "default": None, - "description": "Da PK", - "docstring": None, - "constraints": {"ge": -2147483648, "le": 2147483647}, - }, - "data_fields": [ - { - "name": "chars", - "field_type": "CharField", - "db_column": "chars", - "db_field_types": { - "": "VARCHAR(50)", - "oracle": "NVARCHAR2(50)", - }, - "python_type": "str", - "generated": False, - "nullable": False, - "unique": False, - "indexed": True, - "default": None, - "description": "Some chars", - "docstring": None, - "constraints": {"max_length": 50}, - }, - { - "name": "blip", - "field_type": "CharField", - "db_column": "blip", - "db_field_types": { - "": "VARCHAR(50)", - "oracle": "NVARCHAR2(50)", - }, - "python_type": "str", - "generated": False, - "nullable": False, - "unique": False, - "indexed": False, - "default": "BLIP", - "description": None, - "docstring": None, - "constraints": {"max_length": 50}, - }, - { - "name": "nullable", - "field_type": "CharField", - "db_column": "nullable", - "db_field_types": { - "": "VARCHAR(50)", - "oracle": "NVARCHAR2(50)", - }, - "python_type": "str", - "generated": False, - "nullable": True, - "unique": False, - "indexed": False, - "default": None, - "description": None, - "docstring": None, - "constraints": {"max_length": 50}, - }, - { - "name": "fk_id", - "field_type": "IntField", - "db_column": "fk_id", - "db_field_types": {"": "INT"}, - "python_type": "int", - "generated": False, - "nullable": True, - "unique": False, - "indexed": False, - "default": None, - "description": "Tree!", - "docstring": None, - "constraints": {"ge": -2147483648, "le": 2147483647}, - }, - { - "db_column": "o2o_id", - "db_field_types": {"": "INT"}, - "default": None, - "description": "Line", - "docstring": None, - "field_type": "IntField", - "generated": False, - "indexed": True, - "name": "o2o_id", - "nullable": True, - 
"python_type": "int", - "unique": True, - "constraints": {"ge": -2147483648, "le": 2147483647}, - }, - ], - "fk_fields": [ - { - "name": "fk", - "field_type": "ForeignKeyFieldInstance", - "raw_field": "fk_id", - "python_type": "models.StraightFields", - "generated": False, - "nullable": True, - "unique": False, - "on_delete": "NO ACTION", - "indexed": False, - "default": None, - "description": "Tree!", - "docstring": None, - "db_constraint": True, - "constraints": {}, - } - ], - "backward_fk_fields": [ - { - "name": "fkrev", - "field_type": "BackwardFKRelation", - "python_type": "models.StraightFields", - "generated": False, - "nullable": True, - "unique": False, - "indexed": False, - "default": None, - "description": "Tree!", - "docstring": None, - "constraints": {}, - "db_constraint": True, - } - ], - "o2o_fields": [ - { - "default": None, - "description": "Line", - "docstring": None, - "field_type": "OneToOneFieldInstance", - "generated": False, - "indexed": True, - "name": "o2o", - "nullable": True, - "on_delete": "NO ACTION", - "python_type": "models.StraightFields", - "raw_field": "o2o_id", - "unique": True, - "constraints": {}, - "db_constraint": True, - } - ], - "backward_o2o_fields": [ - { - "default": None, - "description": "Line", - "docstring": None, - "field_type": "BackwardOneToOneRelation", - "generated": False, - "indexed": False, - "name": "o2o_rev", - "nullable": True, - "python_type": "models.StraightFields", - "unique": False, - "constraints": {}, - "db_constraint": True, - } - ], - "m2m_fields": [ - { - "name": "rel_to", - "field_type": "ManyToManyFieldInstance", - "python_type": "models.StraightFields", - "generated": False, - "nullable": False, - "unique": True, - "indexed": True, - "default": None, - "description": "M2M to myself", - "docstring": None, - "constraints": {}, - "model_name": "models.StraightFields", - "related_name": "rel_from", - "forward_key": "straightfields_id", - "backward_key": "straightfields_rel_id", - "through": 
"straightfields_straightfields", - "on_delete": "NO ACTION", - "_generated": False, - "db_constraint": True, - }, - { - "name": "rel_from", - "field_type": "ManyToManyFieldInstance", - "python_type": "models.StraightFields", - "generated": False, - "nullable": False, - "unique": True, - "indexed": True, - "default": None, - "description": "M2M to myself", - "docstring": None, - "constraints": {}, - "model_name": "models.StraightFields", - "related_name": "rel_to", - "forward_key": "straightfields_rel_id", - "backward_key": "straightfields_id", - "through": "straightfields_straightfields", - "on_delete": "CASCADE", - "db_constraint": True, - "_generated": True, - }, - ], + "field_type": OneToOneFieldInstance, + "generated": False, + "indexed": True, + "name": "o2o", + "nullable": True, + "on_delete": "NO ACTION", + "python_type": StraightFields, + "raw_field": "o2o_id", + "unique": True, + "constraints": {}, + "db_constraint": True, }, - ) - - def test_describe_model_straight_native(self): - val = StraightFields.describe(serializable=False) - self.assertEqual( - val, + ], + "backward_o2o_fields": [ { - "name": "models.StraightFields", - "app": "models", - "table": "straightfields", - "abstract": False, - "description": "Straight auto-mapped fields", + "default": None, + "description": "Line", "docstring": None, - "unique_together": [["chars", "blip"]], - "indexes": [], - "pk_field": { - "name": "eyedee", - "field_type": fields.IntField, - "db_column": "eyedee", - "db_field_types": {"": "INT"}, - "python_type": int, - "generated": True, - "nullable": False, - "unique": True, - "indexed": True, - "default": None, - "description": "Da PK", - "docstring": None, - "constraints": {"ge": -2147483648, "le": 2147483647}, - }, - "data_fields": [ - { - "name": "chars", - "field_type": fields.CharField, - "db_column": "chars", - "db_field_types": { - "": "VARCHAR(50)", - "oracle": "NVARCHAR2(50)", - }, - "python_type": str, - "generated": False, - "nullable": False, - "unique": 
False, - "indexed": True, - "default": None, - "description": "Some chars", - "docstring": None, - "constraints": {"max_length": 50}, - }, - { - "name": "blip", - "field_type": fields.CharField, - "db_column": "blip", - "db_field_types": { - "": "VARCHAR(50)", - "oracle": "NVARCHAR2(50)", - }, - "python_type": str, - "generated": False, - "nullable": False, - "unique": False, - "indexed": False, - "default": "BLIP", - "description": None, - "docstring": None, - "constraints": {"max_length": 50}, - }, - { - "name": "nullable", - "field_type": fields.CharField, - "db_column": "nullable", - "db_field_types": { - "": "VARCHAR(50)", - "oracle": "NVARCHAR2(50)", - }, - "python_type": str, - "generated": False, - "nullable": True, - "unique": False, - "indexed": False, - "default": None, - "description": None, - "docstring": None, - "constraints": {"max_length": 50}, - }, - { - "name": "fk_id", - "field_type": fields.IntField, - "db_column": "fk_id", - "db_field_types": {"": "INT"}, - "python_type": int, - "generated": False, - "nullable": True, - "unique": False, - "indexed": False, - "default": None, - "description": "Tree!", - "docstring": None, - "constraints": {"ge": -2147483648, "le": 2147483647}, - }, - { - "name": "o2o_id", - "field_type": fields.IntField, - "db_column": "o2o_id", - "db_field_types": {"": "INT"}, - "python_type": int, - "generated": False, - "nullable": True, - "unique": True, - "indexed": True, - "default": None, - "description": "Line", - "docstring": None, - "constraints": {"ge": -2147483648, "le": 2147483647}, - }, - ], - "fk_fields": [ - { - "name": "fk", - "field_type": ForeignKeyFieldInstance, - "raw_field": "fk_id", - "python_type": StraightFields, - "generated": False, - "nullable": True, - "on_delete": "NO ACTION", - "unique": False, - "indexed": False, - "default": None, - "description": "Tree!", - "docstring": None, - "constraints": {}, - "db_constraint": True, - } - ], - "backward_fk_fields": [ - { - "name": "fkrev", - "field_type": 
BackwardFKRelation, - "python_type": StraightFields, - "generated": False, - "nullable": True, - "unique": False, - "indexed": False, - "default": None, - "description": "Tree!", - "docstring": None, - "constraints": {}, - "db_constraint": True, - } - ], - "o2o_fields": [ - { - "default": None, - "description": "Line", - "docstring": None, - "field_type": OneToOneFieldInstance, - "generated": False, - "indexed": True, - "name": "o2o", - "nullable": True, - "on_delete": "NO ACTION", - "python_type": StraightFields, - "raw_field": "o2o_id", - "unique": True, - "constraints": {}, - "db_constraint": True, - }, - ], - "backward_o2o_fields": [ - { - "default": None, - "description": "Line", - "docstring": None, - "field_type": fields.BackwardOneToOneRelation, - "generated": False, - "indexed": False, - "name": "o2o_rev", - "nullable": True, - "python_type": StraightFields, - "unique": False, - "constraints": {}, - "db_constraint": True, - }, - ], - "m2m_fields": [ - { - "name": "rel_to", - "field_type": ManyToManyFieldInstance, - "python_type": StraightFields, - "generated": False, - "nullable": False, - "unique": True, - "indexed": True, - "default": None, - "description": "M2M to myself", - "docstring": None, - "constraints": {}, - "model_name": "models.StraightFields", - "related_name": "rel_from", - "forward_key": "straightfields_id", - "backward_key": "straightfields_rel_id", - "through": "straightfields_straightfields", - "on_delete": "NO ACTION", - "_generated": False, - "db_constraint": True, - }, - { - "name": "rel_from", - "field_type": ManyToManyFieldInstance, - "python_type": StraightFields, - "generated": False, - "nullable": False, - "unique": True, - "indexed": True, - "default": None, - "description": "M2M to myself", - "docstring": None, - "constraints": {}, - "model_name": "models.StraightFields", - "related_name": "rel_to", - "forward_key": "straightfields_rel_id", - "backward_key": "straightfields_id", - "through": "straightfields_straightfields", - 
"on_delete": "CASCADE", - "_generated": True, - "db_constraint": True, - }, - ], + "field_type": fields.BackwardOneToOneRelation, + "generated": False, + "indexed": False, + "name": "o2o_rev", + "nullable": True, + "python_type": StraightFields, + "unique": False, + "constraints": {}, + "db_constraint": True, }, - ) - - def test_describe_model_source(self): - val = SourceFields.describe() - self.assertEqual( - val, + ], + "m2m_fields": [ { - "name": "models.SourceFields", - "app": "models", - "table": "sometable", - "abstract": False, - "description": "Source mapped fields", - "docstring": "A Docstring.", - "unique_together": [["chars", "blip"]], - "indexes": [], - "pk_field": { - "name": "eyedee", - "field_type": "IntField", - "db_column": "sometable_id", - "db_field_types": {"": "INT"}, - "python_type": "int", - "generated": True, - "nullable": False, - "unique": True, - "indexed": True, - "default": None, - "description": "Da PK", - "docstring": None, - "constraints": {"ge": -2147483648, "le": 2147483647}, - }, - "data_fields": [ - { - "name": "chars", - "field_type": "CharField", - "db_column": "some_chars_table", - "db_field_types": { - "": "VARCHAR(50)", - "oracle": "NVARCHAR2(50)", - }, - "python_type": "str", - "generated": False, - "nullable": False, - "unique": False, - "indexed": True, - "default": None, - "description": "Some chars", - "docstring": None, - "constraints": {"max_length": 50}, - }, - { - "name": "blip", - "field_type": "CharField", - "db_column": "da_blip", - "db_field_types": { - "": "VARCHAR(50)", - "oracle": "NVARCHAR2(50)", - }, - "python_type": "str", - "generated": False, - "nullable": False, - "unique": False, - "indexed": False, - "default": "BLIP", - "description": "A docstring comment", - "docstring": "A docstring comment", - "constraints": {"max_length": 50}, - }, - { - "name": "nullable", - "field_type": "CharField", - "db_column": "some_nullable", - "db_field_types": { - "": "VARCHAR(50)", - "oracle": "NVARCHAR2(50)", - }, - 
"python_type": "str", - "generated": False, - "nullable": True, - "unique": False, - "indexed": False, - "default": None, - "description": None, - "docstring": None, - "constraints": {"max_length": 50}, - }, - { - "name": "fk_id", - "field_type": "IntField", - "db_column": "fk_sometable", - "db_field_types": {"": "INT"}, - "python_type": "int", - "generated": False, - "nullable": True, - "unique": False, - "indexed": False, - "default": None, - "description": "Tree!", - "docstring": None, - "constraints": {"ge": -2147483648, "le": 2147483647}, - }, - { - "name": "o2o_id", - "field_type": "IntField", - "db_column": "o2o_sometable", - "db_field_types": {"": "INT"}, - "python_type": "int", - "generated": False, - "nullable": True, - "unique": True, - "indexed": True, - "default": None, - "description": "Line", - "docstring": None, - "constraints": {"ge": -2147483648, "le": 2147483647}, - }, - ], - "fk_fields": [ - { - "name": "fk", - "field_type": "ForeignKeyFieldInstance", - "raw_field": "fk_id", - "python_type": "models.SourceFields", - "generated": False, - "nullable": True, - "on_delete": "NO ACTION", - "unique": False, - "indexed": False, - "default": None, - "description": "Tree!", - "docstring": None, - "constraints": {}, - "db_constraint": True, - } - ], - "backward_fk_fields": [ - { - "name": "fkrev", - "field_type": "BackwardFKRelation", - "python_type": "models.SourceFields", - "generated": False, - "nullable": True, - "unique": False, - "indexed": False, - "default": None, - "description": "Tree!", - "docstring": None, - "constraints": {}, - "db_constraint": True, - } - ], - "o2o_fields": [ - { - "default": None, - "description": "Line", - "docstring": None, - "field_type": "OneToOneFieldInstance", - "generated": False, - "indexed": True, - "name": "o2o", - "nullable": True, - "on_delete": "NO ACTION", - "python_type": "models.SourceFields", - "raw_field": "o2o_id", - "unique": True, - "constraints": {}, - "db_constraint": True, - } - ], - 
"backward_o2o_fields": [ - { - "default": None, - "description": "Line", - "docstring": None, - "field_type": "BackwardOneToOneRelation", - "generated": False, - "indexed": False, - "name": "o2o_rev", - "nullable": True, - "python_type": "models.SourceFields", - "unique": False, - "constraints": {}, - "db_constraint": True, - } - ], - "m2m_fields": [ - { - "name": "rel_to", - "field_type": "ManyToManyFieldInstance", - "python_type": "models.SourceFields", - "generated": False, - "nullable": False, - "unique": True, - "indexed": True, - "default": None, - "description": "M2M to myself", - "docstring": None, - "constraints": {}, - "model_name": "models.SourceFields", - "related_name": "rel_from", - "forward_key": "sts_forward", - "backward_key": "backward_sts", - "through": "sometable_self", - "on_delete": "NO ACTION", - "_generated": False, - "db_constraint": True, - }, - { - "name": "rel_from", - "field_type": "ManyToManyFieldInstance", - "python_type": "models.SourceFields", - "generated": False, - "nullable": False, - "unique": True, - "indexed": True, - "default": None, - "description": "M2M to myself", - "docstring": None, - "constraints": {}, - "model_name": "models.SourceFields", - "related_name": "rel_to", - "forward_key": "backward_sts", - "backward_key": "sts_forward", - "through": "sometable_self", - "on_delete": "CASCADE", - "_generated": True, - "db_constraint": True, - }, - ], + "name": "rel_to", + "field_type": ManyToManyFieldInstance, + "python_type": StraightFields, + "generated": False, + "nullable": False, + "unique": True, + "indexed": True, + "default": None, + "description": "M2M to myself", + "docstring": None, + "constraints": {}, + "model_name": "models.StraightFields", + "related_name": "rel_from", + "forward_key": "straightfields_id", + "backward_key": "straightfields_rel_id", + "through": "straightfields_straightfields", + "on_delete": "NO ACTION", + "_generated": False, + "db_constraint": True, }, - ) - - def 
test_describe_model_source_native(self): - val = SourceFields.describe(serializable=False) - self.assertEqual( - val, { - "name": "models.SourceFields", - "app": "models", - "table": "sometable", - "abstract": False, - "description": "Source mapped fields", - "docstring": "A Docstring.", - "unique_together": [["chars", "blip"]], - "indexes": [], - "pk_field": { - "name": "eyedee", - "field_type": fields.IntField, - "db_column": "sometable_id", - "db_field_types": {"": "INT"}, - "python_type": int, - "generated": True, - "nullable": False, - "unique": True, - "indexed": True, - "default": None, - "description": "Da PK", - "docstring": None, - "constraints": {"ge": -2147483648, "le": 2147483647}, - }, - "data_fields": [ - { - "name": "chars", - "field_type": fields.CharField, - "db_column": "some_chars_table", - "db_field_types": { - "": "VARCHAR(50)", - "oracle": "NVARCHAR2(50)", - }, - "python_type": str, - "generated": False, - "nullable": False, - "unique": False, - "indexed": True, - "default": None, - "description": "Some chars", - "docstring": None, - "constraints": {"max_length": 50}, - }, - { - "name": "blip", - "field_type": fields.CharField, - "db_column": "da_blip", - "db_field_types": { - "": "VARCHAR(50)", - "oracle": "NVARCHAR2(50)", - }, - "python_type": str, - "generated": False, - "nullable": False, - "unique": False, - "indexed": False, - "default": "BLIP", - "description": "A docstring comment", - "docstring": "A docstring comment", - "constraints": {"max_length": 50}, - }, - { - "name": "nullable", - "field_type": fields.CharField, - "db_column": "some_nullable", - "db_field_types": { - "": "VARCHAR(50)", - "oracle": "NVARCHAR2(50)", - }, - "python_type": str, - "generated": False, - "nullable": True, - "unique": False, - "indexed": False, - "default": None, - "description": None, - "docstring": None, - "constraints": {"max_length": 50}, - }, - { - "name": "fk_id", - "field_type": fields.IntField, - "db_column": "fk_sometable", - 
"db_field_types": {"": "INT"}, - "python_type": int, - "generated": False, - "nullable": True, - "unique": False, - "indexed": False, - "default": None, - "description": "Tree!", - "docstring": None, - "constraints": {"ge": -2147483648, "le": 2147483647}, - }, - { - "name": "o2o_id", - "field_type": fields.IntField, - "db_column": "o2o_sometable", - "db_field_types": {"": "INT"}, - "python_type": int, - "generated": False, - "nullable": True, - "unique": True, - "indexed": True, - "default": None, - "description": "Line", - "docstring": None, - "constraints": {"ge": -2147483648, "le": 2147483647}, - }, - ], - "fk_fields": [ - { - "name": "fk", - "field_type": ForeignKeyFieldInstance, - "raw_field": "fk_id", - "python_type": SourceFields, - "generated": False, - "nullable": True, - "unique": False, - "on_delete": "NO ACTION", - "indexed": False, - "default": None, - "description": "Tree!", - "docstring": None, - "constraints": {}, - "db_constraint": True, - } - ], - "backward_fk_fields": [ - { - "name": "fkrev", - "field_type": BackwardFKRelation, - "python_type": SourceFields, - "generated": False, - "nullable": True, - "unique": False, - "indexed": False, - "default": None, - "description": "Tree!", - "docstring": None, - "constraints": {}, - "db_constraint": True, - } - ], - "o2o_fields": [ - { - "default": None, - "description": "Line", - "docstring": None, - "field_type": OneToOneFieldInstance, - "generated": False, - "indexed": True, - "name": "o2o", - "nullable": True, - "on_delete": "NO ACTION", - "python_type": SourceFields, - "raw_field": "o2o_id", - "unique": True, - "constraints": {}, - "db_constraint": True, - } - ], - "backward_o2o_fields": [ - { - "default": None, - "description": "Line", - "docstring": None, - "field_type": fields.BackwardOneToOneRelation, - "generated": False, - "indexed": False, - "name": "o2o_rev", - "nullable": True, - "python_type": SourceFields, - "unique": False, - "constraints": {}, - "db_constraint": True, - } - ], - 
"m2m_fields": [ - { - "name": "rel_to", - "field_type": ManyToManyFieldInstance, - "python_type": SourceFields, - "generated": False, - "nullable": False, - "unique": True, - "indexed": True, - "default": None, - "description": "M2M to myself", - "docstring": None, - "constraints": {}, - "model_name": "models.SourceFields", - "related_name": "rel_from", - "forward_key": "sts_forward", - "backward_key": "backward_sts", - "through": "sometable_self", - "on_delete": "NO ACTION", - "_generated": False, - "db_constraint": True, - }, - { - "name": "rel_from", - "field_type": ManyToManyFieldInstance, - "python_type": SourceFields, - "generated": False, - "nullable": False, - "unique": True, - "indexed": True, - "default": None, - "description": "M2M to myself", - "docstring": None, - "constraints": {}, - "model_name": "models.SourceFields", - "related_name": "rel_to", - "forward_key": "backward_sts", - "backward_key": "sts_forward", - "through": "sometable_self", - "on_delete": "CASCADE", - "_generated": True, - "db_constraint": True, - }, - ], + "name": "rel_from", + "field_type": ManyToManyFieldInstance, + "python_type": StraightFields, + "generated": False, + "nullable": False, + "unique": True, + "indexed": True, + "default": None, + "description": "M2M to myself", + "docstring": None, + "constraints": {}, + "model_name": "models.StraightFields", + "related_name": "rel_to", + "forward_key": "straightfields_rel_id", + "backward_key": "straightfields_id", + "through": "straightfields_straightfields", + "on_delete": "CASCADE", + "_generated": True, + "db_constraint": True, }, - ) + ], + } - def test_describe_model_uuidpk(self): - val = UUIDPkModel.describe() - self.assertEqual( - val, +def test_describe_model_source(): + val = SourceFields.describe() + assert val == { + "name": "models.SourceFields", + "app": "models", + "table": "sometable", + "abstract": False, + "description": "Source mapped fields", + "docstring": "A Docstring.", + "unique_together": [["chars", 
"blip"]], + "indexes": [], + "pk_field": { + "name": "eyedee", + "field_type": "IntField", + "db_column": "sometable_id", + "db_field_types": {"": "INT"}, + "python_type": "int", + "generated": True, + "nullable": False, + "unique": True, + "indexed": True, + "default": None, + "description": "Da PK", + "docstring": None, + "constraints": {"ge": -2147483648, "le": 2147483647}, + }, + "data_fields": [ { - "name": "models.UUIDPkModel", - "app": "models", - "table": "uuidpkmodel", - "abstract": False, - "description": None, - "docstring": None, - "unique_together": [], - "indexes": [], - "pk_field": { - "name": "id", - "field_type": "UUIDField", - "db_column": "id", - "db_field_types": {"": "CHAR(36)", "postgres": "UUID"}, - "python_type": "uuid.UUID", - "generated": False, - "nullable": False, - "unique": True, - "indexed": True, - "default": "", - "description": None, - "docstring": None, - "constraints": {}, + "name": "chars", + "field_type": "CharField", + "db_column": "some_chars_table", + "db_field_types": { + "": "VARCHAR(50)", + "oracle": "NVARCHAR2(50)", }, - "data_fields": [], - "fk_fields": [], - "backward_fk_fields": [ - { - "name": "children", - "field_type": "BackwardFKRelation", - "python_type": "models.UUIDFkRelatedModel", - "generated": False, - "nullable": False, - "unique": False, - "indexed": False, - "default": None, - "description": None, - "docstring": None, - "constraints": {}, - "db_constraint": True, - }, - ], - "o2o_fields": [], - "backward_o2o_fields": [], - "m2m_fields": [ - { - "_generated": True, - "backward_key": "uuidpkmodel_id", - "constraints": {}, - "db_constraint": True, - "default": None, - "description": None, - "docstring": None, - "field_type": "ManyToManyFieldInstance", - "forward_key": "uuidm2mrelatedmodel_id", - "generated": False, - "indexed": True, - "model_name": "models.UUIDM2MRelatedModel", - "name": "peers", - "nullable": False, - "on_delete": "CASCADE", - "python_type": "models.UUIDM2MRelatedModel", - "related_name": 
"models", - "through": "uuidm2mrelatedmodel_uuidpkmodel", - "unique": True, - } - ], + "python_type": "str", + "generated": False, + "nullable": False, + "unique": False, + "indexed": True, + "default": None, + "description": "Some chars", + "docstring": None, + "constraints": {"max_length": 50}, }, - ) - - def test_describe_model_uuidpk_native(self): - val = UUIDPkModel.describe(serializable=False) - self.assertEqual( - val, { - "name": "models.UUIDPkModel", - "app": "models", - "table": "uuidpkmodel", - "abstract": False, - "description": None, - "docstring": None, - "unique_together": [], - "indexes": [], - "pk_field": { - "name": "id", - "field_type": fields.UUIDField, - "db_column": "id", - "db_field_types": {"": "CHAR(36)", "postgres": "UUID"}, - "python_type": uuid.UUID, - "generated": False, - "nullable": False, - "unique": True, - "indexed": True, - "default": uuid.uuid4, - "description": None, - "docstring": None, - "constraints": {}, + "name": "blip", + "field_type": "CharField", + "db_column": "da_blip", + "db_field_types": { + "": "VARCHAR(50)", + "oracle": "NVARCHAR2(50)", }, - "data_fields": [], - "fk_fields": [], - "backward_fk_fields": [ - { - "name": "children", - "field_type": BackwardFKRelation, - "python_type": UUIDFkRelatedModel, - "generated": False, - "nullable": False, - "unique": False, - "indexed": False, - "default": None, - "description": None, - "docstring": None, - "constraints": {}, - "db_constraint": True, - }, - ], - "o2o_fields": [], - "backward_o2o_fields": [], - "m2m_fields": [ - { - "name": "peers", - "db_constraint": True, - "generated": False, - "nullable": False, - "field_type": ManyToManyFieldInstance, - "python_type": UUIDM2MRelatedModel, - "unique": True, - "indexed": True, - "default": None, - "description": None, - "docstring": None, - "constraints": {}, - "model_name": "models.UUIDM2MRelatedModel", - "related_name": "models", - "forward_key": "uuidm2mrelatedmodel_id", - "backward_key": "uuidpkmodel_id", - "through": 
"uuidm2mrelatedmodel_uuidpkmodel", - "on_delete": "CASCADE", - "_generated": True, - } - ], + "python_type": "str", + "generated": False, + "nullable": False, + "unique": False, + "indexed": False, + "default": "BLIP", + "description": "A docstring comment", + "docstring": "A docstring comment", + "constraints": {"max_length": 50}, }, - ) - - def test_describe_model_uuidpk_relatednull(self): - val = UUIDFkRelatedNullModel.describe(serializable=True) - - self.assertEqual( - val, { - "abstract": False, - "app": "models", - "backward_fk_fields": [], - "backward_o2o_fields": [], - "data_fields": [ - { - "db_column": "name", - "db_field_types": { - "": "VARCHAR(50)", - "oracle": "NVARCHAR2(50)", - }, - "default": None, - "description": None, - "docstring": None, - "field_type": "CharField", - "generated": False, - "indexed": False, - "name": "name", - "nullable": True, - "python_type": "str", - "unique": False, - "constraints": {"max_length": 50}, - }, - { - "db_column": "model_id", - "db_field_types": {"": "CHAR(36)", "postgres": "UUID"}, - "default": None, - "description": None, - "docstring": None, - "field_type": "UUIDField", - "generated": False, - "indexed": False, - "name": "model_id", - "nullable": True, - "python_type": "uuid.UUID", - "unique": False, - "constraints": {}, - }, - { - "db_column": "parent_id", - "db_field_types": {"": "CHAR(36)", "postgres": "UUID"}, - "default": None, - "description": None, - "docstring": None, - "field_type": "UUIDField", - "generated": False, - "indexed": True, - "name": "parent_id", - "nullable": True, - "python_type": "uuid.UUID", - "unique": True, - "constraints": {}, - }, - ], + "name": "nullable", + "field_type": "CharField", + "db_column": "some_nullable", + "db_field_types": { + "": "VARCHAR(50)", + "oracle": "NVARCHAR2(50)", + }, + "python_type": "str", + "generated": False, + "nullable": True, + "unique": False, + "indexed": False, + "default": None, "description": None, "docstring": None, - "fk_fields": [ - { - 
"default": None, - "description": None, - "docstring": None, - "field_type": "ForeignKeyFieldInstance", - "generated": False, - "indexed": False, - "name": "model", - "nullable": True, - "on_delete": "CASCADE", - "python_type": "models.UUIDPkModel", - "raw_field": "model_id", - "unique": False, - "constraints": {}, - "db_constraint": True, - } - ], - "m2m_fields": [], - "name": "models.UUIDFkRelatedNullModel", - "o2o_fields": [ - { - "default": None, - "description": None, - "docstring": None, - "field_type": "OneToOneFieldInstance", - "generated": False, - "indexed": True, - "name": "parent", - "nullable": True, - "on_delete": "NO ACTION", - "python_type": "models.UUIDPkModel", - "raw_field": "parent_id", - "unique": True, - "constraints": {}, - "db_constraint": True, - } - ], - "pk_field": { - "db_column": "id", - "db_field_types": {"": "CHAR(36)", "postgres": "UUID"}, - "default": "", - "description": None, - "docstring": None, - "field_type": "UUIDField", - "generated": False, - "indexed": True, - "name": "id", - "nullable": False, - "python_type": "uuid.UUID", - "unique": True, - "constraints": {}, - }, - "table": "uuidfkrelatednullmodel", - "unique_together": [], - "indexes": [], + "constraints": {"max_length": 50}, }, - ) - - def test_describe_model_json(self): - val = JSONFields.describe() - - self.assertEqual( - val, { - "name": "models.JSONFields", - "app": "models", - "table": "jsonfields", - "abstract": False, - "description": "This model contains many JSON blobs", - "docstring": "This model contains many JSON blobs", - "unique_together": [], - "indexes": [], - "pk_field": { - "name": "id", - "field_type": "IntField", - "db_column": "id", - "db_field_types": {"": "INT"}, - "python_type": "int", - "generated": True, - "nullable": False, - "unique": True, - "indexed": True, - "default": None, - "description": None, - "docstring": None, - "constraints": {"ge": -2147483648, "le": 2147483647}, - }, - "data_fields": [ - { - "name": "data", - "field_type": 
"JSONField", - "db_column": "data", - "db_field_types": { - "": "JSON", - "mssql": "NVARCHAR(MAX)", - "oracle": "NCLOB", - "postgres": "JSONB", - }, - "python_type": UNION_DICT_LIST, - "generated": False, - "nullable": False, - "unique": False, - "indexed": False, - "default": None, - "description": None, - "docstring": None, - "constraints": {}, - }, - { - "name": "data_null", - "field_type": "JSONField", - "db_column": "data_null", - "db_field_types": { - "": "JSON", - "mssql": "NVARCHAR(MAX)", - "oracle": "NCLOB", - "postgres": "JSONB", - }, - "python_type": UNION_DICT_LIST, - "generated": False, - "nullable": True, - "unique": False, - "indexed": False, - "default": None, - "description": None, - "docstring": None, - "constraints": {}, - }, - { - "name": "data_default", - "field_type": "JSONField", - "db_column": "data_default", - "db_field_types": { - "": "JSON", - "mssql": "NVARCHAR(MAX)", - "oracle": "NCLOB", - "postgres": "JSONB", - }, - "python_type": UNION_DICT_LIST, - "generated": False, - "nullable": False, - "unique": False, - "indexed": False, - "default": "{'a': 1}", - "description": None, - "docstring": None, - "constraints": {}, - }, - { - "name": "data_validate", - "field_type": "JSONField", - "db_column": "data_validate", - "db_field_types": { - "": "JSON", - "mssql": "NVARCHAR(MAX)", - "oracle": "NCLOB", - "postgres": "JSONB", - }, - "python_type": UNION_DICT_LIST, - "generated": False, - "nullable": True, - "unique": False, - "indexed": False, - "default": None, - "description": None, - "docstring": None, - "constraints": {}, - }, - { - "name": "data_pydantic", - "field_type": "JSONField", - "db_column": "data_pydantic", - "db_field_types": { - "": "JSON", - "mssql": "NVARCHAR(MAX)", - "oracle": "NCLOB", - "postgres": "JSONB", - }, - "python_type": "tests.testmodels.TestSchemaForJSONField", - "generated": False, - "nullable": False, - "unique": False, - "indexed": False, - "default": "foo=1 bar='baz'", - "description": None, - "docstring": 
None, - "constraints": {}, - }, - ], - "fk_fields": [], - "backward_fk_fields": [], - "o2o_fields": [], - "backward_o2o_fields": [], - "m2m_fields": [], + "name": "fk_id", + "field_type": "IntField", + "db_column": "fk_sometable", + "db_field_types": {"": "INT"}, + "python_type": "int", + "generated": False, + "nullable": True, + "unique": False, + "indexed": False, + "default": None, + "description": "Tree!", + "docstring": None, + "constraints": {"ge": -2147483648, "le": 2147483647}, }, - ) - - def test_describe_model_json_native(self): - val = JSONFields.describe(serializable=False) - - self.assertEqual( - val, { - "name": "models.JSONFields", - "app": "models", - "table": "jsonfields", - "abstract": False, - "description": "This model contains many JSON blobs", - "docstring": "This model contains many JSON blobs", - "unique_together": [], - "indexes": [], - "pk_field": { - "name": "id", - "field_type": fields.IntField, - "db_column": "id", - "db_field_types": {"": "INT"}, - "python_type": int, - "generated": True, - "nullable": False, - "unique": True, - "indexed": True, - "default": None, - "description": None, - "docstring": None, - "constraints": {"ge": -2147483648, "le": 2147483647}, - }, - "data_fields": [ - { - "name": "data", - "field_type": fields.JSONField, - "db_column": "data", - "db_field_types": { - "": "JSON", - "mssql": "NVARCHAR(MAX)", - "oracle": "NCLOB", - "postgres": "JSONB", - }, - "python_type": dict | list, - "generated": False, - "nullable": False, - "unique": False, - "indexed": False, - "default": None, - "description": None, - "docstring": None, - "constraints": {}, - }, - { - "name": "data_null", - "field_type": fields.JSONField, - "db_column": "data_null", - "db_field_types": { - "": "JSON", - "mssql": "NVARCHAR(MAX)", - "oracle": "NCLOB", - "postgres": "JSONB", - }, - "python_type": dict | list, - "generated": False, - "nullable": True, - "unique": False, - "indexed": False, - "default": None, - "description": None, - "docstring": 
None, - "constraints": {}, - }, - { - "name": "data_default", - "field_type": fields.JSONField, - "db_column": "data_default", - "db_field_types": { - "": "JSON", - "mssql": "NVARCHAR(MAX)", - "oracle": "NCLOB", - "postgres": "JSONB", - }, - "python_type": dict | list, - "generated": False, - "nullable": False, - "unique": False, - "indexed": False, - "default": {"a": 1}, - "description": None, - "docstring": None, - "constraints": {}, - }, - { - "name": "data_validate", - "field_type": fields.JSONField, - "db_column": "data_validate", - "db_field_types": { - "": "JSON", - "mssql": "NVARCHAR(MAX)", - "oracle": "NCLOB", - "postgres": "JSONB", - }, - "python_type": dict | list, - "generated": False, - "nullable": True, - "unique": False, - "indexed": False, - "default": None, - "description": None, - "docstring": None, - "constraints": {}, - }, - { - "name": "data_pydantic", - "field_type": fields.JSONField, - "db_column": "data_pydantic", - "db_field_types": { - "": "JSON", - "mssql": "NVARCHAR(MAX)", - "oracle": "NCLOB", - "postgres": "JSONB", - }, - "python_type": TestSchemaForJSONField, - "generated": False, - "nullable": False, - "unique": False, - "indexed": False, - "default": json_pydantic_default, - "description": None, - "docstring": None, - "constraints": {}, - }, - ], - "fk_fields": [], - "backward_fk_fields": [], - "o2o_fields": [], - "backward_o2o_fields": [], - "m2m_fields": [], + "name": "o2o_id", + "field_type": "IntField", + "db_column": "o2o_sometable", + "db_field_types": {"": "INT"}, + "python_type": "int", + "generated": False, + "nullable": True, + "unique": True, + "indexed": True, + "default": None, + "description": "Line", + "docstring": None, + "constraints": {"ge": -2147483648, "le": 2147483647}, }, - ) - - def test_describe_indexes_serializable(self): - val = ModelWithIndexes.describe() - - self.assertEqual( - val["indexes"], - [ - {"fields": ["f1", "f2"], "expressions": [], "name": None, "type": "", "extra": ""}, - { - "fields": ["f3"], 
- "expressions": [], - "name": "model_with_indexes__f3", - "type": "", - "extra": "", - }, - ], - ) + ], + "fk_fields": [ + { + "name": "fk", + "field_type": "ForeignKeyFieldInstance", + "raw_field": "fk_id", + "python_type": "models.SourceFields", + "generated": False, + "nullable": True, + "on_delete": "NO ACTION", + "unique": False, + "indexed": False, + "default": None, + "description": "Tree!", + "docstring": None, + "constraints": {}, + "db_constraint": True, + } + ], + "backward_fk_fields": [ + { + "name": "fkrev", + "field_type": "BackwardFKRelation", + "python_type": "models.SourceFields", + "generated": False, + "nullable": True, + "unique": False, + "indexed": False, + "default": None, + "description": "Tree!", + "docstring": None, + "constraints": {}, + "db_constraint": True, + } + ], + "o2o_fields": [ + { + "default": None, + "description": "Line", + "docstring": None, + "field_type": "OneToOneFieldInstance", + "generated": False, + "indexed": True, + "name": "o2o", + "nullable": True, + "on_delete": "NO ACTION", + "python_type": "models.SourceFields", + "raw_field": "o2o_id", + "unique": True, + "constraints": {}, + "db_constraint": True, + } + ], + "backward_o2o_fields": [ + { + "default": None, + "description": "Line", + "docstring": None, + "field_type": "BackwardOneToOneRelation", + "generated": False, + "indexed": False, + "name": "o2o_rev", + "nullable": True, + "python_type": "models.SourceFields", + "unique": False, + "constraints": {}, + "db_constraint": True, + } + ], + "m2m_fields": [ + { + "name": "rel_to", + "field_type": "ManyToManyFieldInstance", + "python_type": "models.SourceFields", + "generated": False, + "nullable": False, + "unique": True, + "indexed": True, + "default": None, + "description": "M2M to myself", + "docstring": None, + "constraints": {}, + "model_name": "models.SourceFields", + "related_name": "rel_from", + "forward_key": "sts_forward", + "backward_key": "backward_sts", + "through": "sometable_self", + "on_delete": 
"NO ACTION", + "_generated": False, + "db_constraint": True, + }, + { + "name": "rel_from", + "field_type": "ManyToManyFieldInstance", + "python_type": "models.SourceFields", + "generated": False, + "nullable": False, + "unique": True, + "indexed": True, + "default": None, + "description": "M2M to myself", + "docstring": None, + "constraints": {}, + "model_name": "models.SourceFields", + "related_name": "rel_to", + "forward_key": "backward_sts", + "backward_key": "sts_forward", + "through": "sometable_self", + "on_delete": "CASCADE", + "_generated": True, + "db_constraint": True, + }, + ], + } + + +def test_describe_model_source_native(): + val = SourceFields.describe(serializable=False) + assert val == { + "name": "models.SourceFields", + "app": "models", + "table": "sometable", + "abstract": False, + "description": "Source mapped fields", + "docstring": "A Docstring.", + "unique_together": [["chars", "blip"]], + "indexes": [], + "pk_field": { + "name": "eyedee", + "field_type": fields.IntField, + "db_column": "sometable_id", + "db_field_types": {"": "INT"}, + "python_type": int, + "generated": True, + "nullable": False, + "unique": True, + "indexed": True, + "default": None, + "description": "Da PK", + "docstring": None, + "constraints": {"ge": -2147483648, "le": 2147483647}, + }, + "data_fields": [ + { + "name": "chars", + "field_type": fields.CharField, + "db_column": "some_chars_table", + "db_field_types": { + "": "VARCHAR(50)", + "oracle": "NVARCHAR2(50)", + }, + "python_type": str, + "generated": False, + "nullable": False, + "unique": False, + "indexed": True, + "default": None, + "description": "Some chars", + "docstring": None, + "constraints": {"max_length": 50}, + }, + { + "name": "blip", + "field_type": fields.CharField, + "db_column": "da_blip", + "db_field_types": { + "": "VARCHAR(50)", + "oracle": "NVARCHAR2(50)", + }, + "python_type": str, + "generated": False, + "nullable": False, + "unique": False, + "indexed": False, + "default": "BLIP", + 
"description": "A docstring comment", + "docstring": "A docstring comment", + "constraints": {"max_length": 50}, + }, + { + "name": "nullable", + "field_type": fields.CharField, + "db_column": "some_nullable", + "db_field_types": { + "": "VARCHAR(50)", + "oracle": "NVARCHAR2(50)", + }, + "python_type": str, + "generated": False, + "nullable": True, + "unique": False, + "indexed": False, + "default": None, + "description": None, + "docstring": None, + "constraints": {"max_length": 50}, + }, + { + "name": "fk_id", + "field_type": fields.IntField, + "db_column": "fk_sometable", + "db_field_types": {"": "INT"}, + "python_type": int, + "generated": False, + "nullable": True, + "unique": False, + "indexed": False, + "default": None, + "description": "Tree!", + "docstring": None, + "constraints": {"ge": -2147483648, "le": 2147483647}, + }, + { + "name": "o2o_id", + "field_type": fields.IntField, + "db_column": "o2o_sometable", + "db_field_types": {"": "INT"}, + "python_type": int, + "generated": False, + "nullable": True, + "unique": True, + "indexed": True, + "default": None, + "description": "Line", + "docstring": None, + "constraints": {"ge": -2147483648, "le": 2147483647}, + }, + ], + "fk_fields": [ + { + "name": "fk", + "field_type": ForeignKeyFieldInstance, + "raw_field": "fk_id", + "python_type": SourceFields, + "generated": False, + "nullable": True, + "unique": False, + "on_delete": "NO ACTION", + "indexed": False, + "default": None, + "description": "Tree!", + "docstring": None, + "constraints": {}, + "db_constraint": True, + } + ], + "backward_fk_fields": [ + { + "name": "fkrev", + "field_type": BackwardFKRelation, + "python_type": SourceFields, + "generated": False, + "nullable": True, + "unique": False, + "indexed": False, + "default": None, + "description": "Tree!", + "docstring": None, + "constraints": {}, + "db_constraint": True, + } + ], + "o2o_fields": [ + { + "default": None, + "description": "Line", + "docstring": None, + "field_type": 
OneToOneFieldInstance, + "generated": False, + "indexed": True, + "name": "o2o", + "nullable": True, + "on_delete": "NO ACTION", + "python_type": SourceFields, + "raw_field": "o2o_id", + "unique": True, + "constraints": {}, + "db_constraint": True, + } + ], + "backward_o2o_fields": [ + { + "default": None, + "description": "Line", + "docstring": None, + "field_type": fields.BackwardOneToOneRelation, + "generated": False, + "indexed": False, + "name": "o2o_rev", + "nullable": True, + "python_type": SourceFields, + "unique": False, + "constraints": {}, + "db_constraint": True, + } + ], + "m2m_fields": [ + { + "name": "rel_to", + "field_type": ManyToManyFieldInstance, + "python_type": SourceFields, + "generated": False, + "nullable": False, + "unique": True, + "indexed": True, + "default": None, + "description": "M2M to myself", + "docstring": None, + "constraints": {}, + "model_name": "models.SourceFields", + "related_name": "rel_from", + "forward_key": "sts_forward", + "backward_key": "backward_sts", + "through": "sometable_self", + "on_delete": "NO ACTION", + "_generated": False, + "db_constraint": True, + }, + { + "name": "rel_from", + "field_type": ManyToManyFieldInstance, + "python_type": SourceFields, + "generated": False, + "nullable": False, + "unique": True, + "indexed": True, + "default": None, + "description": "M2M to myself", + "docstring": None, + "constraints": {}, + "model_name": "models.SourceFields", + "related_name": "rel_to", + "forward_key": "backward_sts", + "backward_key": "sts_forward", + "through": "sometable_self", + "on_delete": "CASCADE", + "_generated": True, + "db_constraint": True, + }, + ], + } + + +def test_describe_model_uuidpk(): + val = UUIDPkModel.describe() + + assert val == { + "name": "models.UUIDPkModel", + "app": "models", + "table": "uuidpkmodel", + "abstract": False, + "description": None, + "docstring": None, + "unique_together": [], + "indexes": [], + "pk_field": { + "name": "id", + "field_type": "UUIDField", + 
"db_column": "id", + "db_field_types": {"": "CHAR(36)", "postgres": "UUID"}, + "python_type": "uuid.UUID", + "generated": False, + "nullable": False, + "unique": True, + "indexed": True, + "default": "", + "description": None, + "docstring": None, + "constraints": {}, + }, + "data_fields": [], + "fk_fields": [], + "backward_fk_fields": [ + { + "name": "children", + "field_type": "BackwardFKRelation", + "python_type": "models.UUIDFkRelatedModel", + "generated": False, + "nullable": False, + "unique": False, + "indexed": False, + "default": None, + "description": None, + "docstring": None, + "constraints": {}, + "db_constraint": True, + }, + ], + "o2o_fields": [], + "backward_o2o_fields": [], + "m2m_fields": [ + { + "_generated": True, + "backward_key": "uuidpkmodel_id", + "constraints": {}, + "db_constraint": True, + "default": None, + "description": None, + "docstring": None, + "field_type": "ManyToManyFieldInstance", + "forward_key": "uuidm2mrelatedmodel_id", + "generated": False, + "indexed": True, + "model_name": "models.UUIDM2MRelatedModel", + "name": "peers", + "nullable": False, + "on_delete": "CASCADE", + "python_type": "models.UUIDM2MRelatedModel", + "related_name": "models", + "through": "uuidm2mrelatedmodel_uuidpkmodel", + "unique": True, + } + ], + } + + +def test_describe_model_uuidpk_native(): + val = UUIDPkModel.describe(serializable=False) + assert val == { + "name": "models.UUIDPkModel", + "app": "models", + "table": "uuidpkmodel", + "abstract": False, + "description": None, + "docstring": None, + "unique_together": [], + "indexes": [], + "pk_field": { + "name": "id", + "field_type": fields.UUIDField, + "db_column": "id", + "db_field_types": {"": "CHAR(36)", "postgres": "UUID"}, + "python_type": uuid.UUID, + "generated": False, + "nullable": False, + "unique": True, + "indexed": True, + "default": uuid.uuid4, + "description": None, + "docstring": None, + "constraints": {}, + }, + "data_fields": [], + "fk_fields": [], + "backward_fk_fields": [ + { + 
"name": "children", + "field_type": BackwardFKRelation, + "python_type": UUIDFkRelatedModel, + "generated": False, + "nullable": False, + "unique": False, + "indexed": False, + "default": None, + "description": None, + "docstring": None, + "constraints": {}, + "db_constraint": True, + }, + ], + "o2o_fields": [], + "backward_o2o_fields": [], + "m2m_fields": [ + { + "name": "peers", + "db_constraint": True, + "generated": False, + "nullable": False, + "field_type": ManyToManyFieldInstance, + "python_type": UUIDM2MRelatedModel, + "unique": True, + "indexed": True, + "default": None, + "description": None, + "docstring": None, + "constraints": {}, + "model_name": "models.UUIDM2MRelatedModel", + "related_name": "models", + "forward_key": "uuidm2mrelatedmodel_id", + "backward_key": "uuidpkmodel_id", + "through": "uuidm2mrelatedmodel_uuidpkmodel", + "on_delete": "CASCADE", + "_generated": True, + } + ], + } + + +def test_describe_model_uuidpk_relatednull(): + val = UUIDFkRelatedNullModel.describe(serializable=True) + + assert val == { + "abstract": False, + "app": "models", + "backward_fk_fields": [], + "backward_o2o_fields": [], + "data_fields": [ + { + "db_column": "name", + "db_field_types": { + "": "VARCHAR(50)", + "oracle": "NVARCHAR2(50)", + }, + "default": None, + "description": None, + "docstring": None, + "field_type": "CharField", + "generated": False, + "indexed": False, + "name": "name", + "nullable": True, + "python_type": "str", + "unique": False, + "constraints": {"max_length": 50}, + }, + { + "db_column": "model_id", + "db_field_types": {"": "CHAR(36)", "postgres": "UUID"}, + "default": None, + "description": None, + "docstring": None, + "field_type": "UUIDField", + "generated": False, + "indexed": False, + "name": "model_id", + "nullable": True, + "python_type": "uuid.UUID", + "unique": False, + "constraints": {}, + }, + { + "db_column": "parent_id", + "db_field_types": {"": "CHAR(36)", "postgres": "UUID"}, + "default": None, + "description": None, + 
"docstring": None, + "field_type": "UUIDField", + "generated": False, + "indexed": True, + "name": "parent_id", + "nullable": True, + "python_type": "uuid.UUID", + "unique": True, + "constraints": {}, + }, + ], + "description": None, + "docstring": None, + "fk_fields": [ + { + "default": None, + "description": None, + "docstring": None, + "field_type": "ForeignKeyFieldInstance", + "generated": False, + "indexed": False, + "name": "model", + "nullable": True, + "on_delete": "CASCADE", + "python_type": "models.UUIDPkModel", + "raw_field": "model_id", + "unique": False, + "constraints": {}, + "db_constraint": True, + } + ], + "m2m_fields": [], + "name": "models.UUIDFkRelatedNullModel", + "o2o_fields": [ + { + "default": None, + "description": None, + "docstring": None, + "field_type": "OneToOneFieldInstance", + "generated": False, + "indexed": True, + "name": "parent", + "nullable": True, + "on_delete": "NO ACTION", + "python_type": "models.UUIDPkModel", + "raw_field": "parent_id", + "unique": True, + "constraints": {}, + "db_constraint": True, + } + ], + "pk_field": { + "db_column": "id", + "db_field_types": {"": "CHAR(36)", "postgres": "UUID"}, + "default": "", + "description": None, + "docstring": None, + "field_type": "UUIDField", + "generated": False, + "indexed": True, + "name": "id", + "nullable": False, + "python_type": "uuid.UUID", + "unique": True, + "constraints": {}, + }, + "table": "uuidfkrelatednullmodel", + "unique_together": [], + "indexes": [], + } + + +def test_describe_model_json(): + val = JSONFields.describe() + + assert val == { + "name": "models.JSONFields", + "app": "models", + "table": "jsonfields", + "abstract": False, + "description": "This model contains many JSON blobs", + "docstring": "This model contains many JSON blobs", + "unique_together": [], + "indexes": [], + "pk_field": { + "name": "id", + "field_type": "IntField", + "db_column": "id", + "db_field_types": {"": "INT"}, + "python_type": "int", + "generated": True, + "nullable": 
False, + "unique": True, + "indexed": True, + "default": None, + "description": None, + "docstring": None, + "constraints": {"ge": -2147483648, "le": 2147483647}, + }, + "data_fields": [ + { + "name": "data", + "field_type": "JSONField", + "db_column": "data", + "db_field_types": { + "": "JSON", + "mssql": "NVARCHAR(MAX)", + "oracle": "NCLOB", + "postgres": "JSONB", + }, + "python_type": UNION_DICT_LIST, + "generated": False, + "nullable": False, + "unique": False, + "indexed": False, + "default": None, + "description": None, + "docstring": None, + "constraints": {}, + }, + { + "name": "data_null", + "field_type": "JSONField", + "db_column": "data_null", + "db_field_types": { + "": "JSON", + "mssql": "NVARCHAR(MAX)", + "oracle": "NCLOB", + "postgres": "JSONB", + }, + "python_type": UNION_DICT_LIST, + "generated": False, + "nullable": True, + "unique": False, + "indexed": False, + "default": None, + "description": None, + "docstring": None, + "constraints": {}, + }, + { + "name": "data_default", + "field_type": "JSONField", + "db_column": "data_default", + "db_field_types": { + "": "JSON", + "mssql": "NVARCHAR(MAX)", + "oracle": "NCLOB", + "postgres": "JSONB", + }, + "python_type": UNION_DICT_LIST, + "generated": False, + "nullable": False, + "unique": False, + "indexed": False, + "default": "{'a': 1}", + "description": None, + "docstring": None, + "constraints": {}, + }, + { + "name": "data_validate", + "field_type": "JSONField", + "db_column": "data_validate", + "db_field_types": { + "": "JSON", + "mssql": "NVARCHAR(MAX)", + "oracle": "NCLOB", + "postgres": "JSONB", + }, + "python_type": UNION_DICT_LIST, + "generated": False, + "nullable": True, + "unique": False, + "indexed": False, + "default": None, + "description": None, + "docstring": None, + "constraints": {}, + }, + { + "name": "data_pydantic", + "field_type": "JSONField", + "db_column": "data_pydantic", + "db_field_types": { + "": "JSON", + "mssql": "NVARCHAR(MAX)", + "oracle": "NCLOB", + "postgres": 
"JSONB", + }, + "python_type": "tests.testmodels.TestSchemaForJSONField", + "generated": False, + "nullable": False, + "unique": False, + "indexed": False, + "default": "foo=1 bar='baz'", + "description": None, + "docstring": None, + "constraints": {}, + }, + ], + "fk_fields": [], + "backward_fk_fields": [], + "o2o_fields": [], + "backward_o2o_fields": [], + "m2m_fields": [], + } + + +def test_describe_model_json_native(): + val = JSONFields.describe(serializable=False) + + assert val == { + "name": "models.JSONFields", + "app": "models", + "table": "jsonfields", + "abstract": False, + "description": "This model contains many JSON blobs", + "docstring": "This model contains many JSON blobs", + "unique_together": [], + "indexes": [], + "pk_field": { + "name": "id", + "field_type": fields.IntField, + "db_column": "id", + "db_field_types": {"": "INT"}, + "python_type": int, + "generated": True, + "nullable": False, + "unique": True, + "indexed": True, + "default": None, + "description": None, + "docstring": None, + "constraints": {"ge": -2147483648, "le": 2147483647}, + }, + "data_fields": [ + { + "name": "data", + "field_type": fields.JSONField, + "db_column": "data", + "db_field_types": { + "": "JSON", + "mssql": "NVARCHAR(MAX)", + "oracle": "NCLOB", + "postgres": "JSONB", + }, + "python_type": dict | list, + "generated": False, + "nullable": False, + "unique": False, + "indexed": False, + "default": None, + "description": None, + "docstring": None, + "constraints": {}, + }, + { + "name": "data_null", + "field_type": fields.JSONField, + "db_column": "data_null", + "db_field_types": { + "": "JSON", + "mssql": "NVARCHAR(MAX)", + "oracle": "NCLOB", + "postgres": "JSONB", + }, + "python_type": dict | list, + "generated": False, + "nullable": True, + "unique": False, + "indexed": False, + "default": None, + "description": None, + "docstring": None, + "constraints": {}, + }, + { + "name": "data_default", + "field_type": fields.JSONField, + "db_column": "data_default", + 
"db_field_types": { + "": "JSON", + "mssql": "NVARCHAR(MAX)", + "oracle": "NCLOB", + "postgres": "JSONB", + }, + "python_type": dict | list, + "generated": False, + "nullable": False, + "unique": False, + "indexed": False, + "default": {"a": 1}, + "description": None, + "docstring": None, + "constraints": {}, + }, + { + "name": "data_validate", + "field_type": fields.JSONField, + "db_column": "data_validate", + "db_field_types": { + "": "JSON", + "mssql": "NVARCHAR(MAX)", + "oracle": "NCLOB", + "postgres": "JSONB", + }, + "python_type": dict | list, + "generated": False, + "nullable": True, + "unique": False, + "indexed": False, + "default": None, + "description": None, + "docstring": None, + "constraints": {}, + }, + { + "name": "data_pydantic", + "field_type": fields.JSONField, + "db_column": "data_pydantic", + "db_field_types": { + "": "JSON", + "mssql": "NVARCHAR(MAX)", + "oracle": "NCLOB", + "postgres": "JSONB", + }, + "python_type": TestSchemaForJSONField, + "generated": False, + "nullable": False, + "unique": False, + "indexed": False, + "default": json_pydantic_default, + "description": None, + "docstring": None, + "constraints": {}, + }, + ], + "fk_fields": [], + "backward_fk_fields": [], + "o2o_fields": [], + "backward_o2o_fields": [], + "m2m_fields": [], + } + + +def test_describe_indexes_serializable(): + val = ModelWithIndexes.describe() + + assert val["indexes"] == [ + {"fields": ["f1", "f2"], "expressions": [], "name": None, "type": "", "extra": ""}, + { + "fields": ["f3"], + "expressions": [], + "name": "model_with_indexes__f3", + "type": "", + "extra": "", + }, + ] + - def test_describe_indexes_not_serializable(self): - val = ModelWithIndexes.describe(serializable=False) +def test_describe_indexes_not_serializable(): + val = ModelWithIndexes.describe(serializable=False) - self.assertEqual( - val["indexes"], - ModelWithIndexes._meta.indexes, - ) + assert val["indexes"] == ModelWithIndexes._meta.indexes diff --git a/tests/utils/test_run_async.py 
b/tests/utils/test_run_async.py index 88f202458..0b4044ea9 100644 --- a/tests/utils/test_run_async.py +++ b/tests/utils/test_run_async.py @@ -1,40 +1,78 @@ +"""Tests for run_async function. + +These tests verify that run_async properly cleans up Tortoise state after execution. +Since run_async is designed to run a coroutine and clean up all Tortoise state afterwards, +we need to verify this cleanup behavior works correctly. + +Note: These tests require no active TortoiseContext when they start. If run after tests +that use the session-scoped `db` fixture, they will be skipped. +""" + import os -from unittest import skipIf - -from tortoise import Tortoise, connections, run_async -from tortoise.contrib.test import SimpleTestCase - - -@skipIf(os.name == "nt", "stuck with Windows") -class TestRunAsync(SimpleTestCase): - def setUp(self): - self.somevalue = 1 - - def tearDown(self): - run_async(self.asyncTearDown()) - - async def init(self): - await Tortoise.init(db_url="sqlite://:memory:", modules={"models": []}) - self.somevalue = 2 - self.assertNotEqual(connections._get_storage(), {}) - - async def init_raise(self): - await Tortoise.init(db_url="sqlite://:memory:", modules={"models": []}) - self.somevalue = 3 - self.assertNotEqual(connections._get_storage(), {}) - raise Exception("Some exception") - - def test_run_async(self): - self.assertEqual(connections._get_storage(), {}) - self.assertEqual(self.somevalue, 1) - run_async(self.init()) - self.assertEqual(connections._get_storage(), {}) - self.assertEqual(self.somevalue, 2) - - def test_run_async_raised(self): - self.assertEqual(connections._get_storage(), {}) - self.assertEqual(self.somevalue, 1) - with self.assertRaises(Exception): - run_async(self.init_raise()) - self.assertEqual(connections._get_storage(), {}) - self.assertEqual(self.somevalue, 3) + +import pytest + +from tortoise import Tortoise, run_async +from tortoise.context import get_current_context + + +@pytest.fixture +def holder(): + return {"value": 
1} + + +@pytest.fixture +def requires_no_context(): + """Skip the test if there's already an active TortoiseContext.""" + if get_current_context() is not None: + pytest.skip("Test requires no active TortoiseContext - run in isolation") + + +async def init_and_check(holder): + """Initialize Tortoise and verify context is set up.""" + await Tortoise.init(db_url="sqlite://:memory:", modules={"models": ["tests.testmodels"]}) + holder["value"] = 2 + # Verify we have an active context + ctx = get_current_context() + assert ctx is not None + assert ctx.connections._get_storage() != {} + + +async def init_and_raise(holder): + """Initialize Tortoise and raise an exception.""" + await Tortoise.init(db_url="sqlite://:memory:", modules={"models": ["tests.testmodels"]}) + holder["value"] = 3 + # Verify we have an active context + ctx = get_current_context() + assert ctx is not None + assert ctx.connections._get_storage() != {} + raise Exception("Some exception") + + +@pytest.mark.skipif(os.name == "nt", reason="stuck with Windows") +def test_run_async(holder, requires_no_context): + """Test that run_async properly cleans up after successful execution.""" + # No context should be active before run_async + assert get_current_context() is None + assert holder["value"] == 1 + + run_async(init_and_check(holder)) + + # After run_async, context should be cleaned up + assert get_current_context() is None + assert holder["value"] == 2 + + +@pytest.mark.skipif(os.name == "nt", reason="stuck with Windows") +def test_run_async_raised(holder, requires_no_context): + """Test that run_async properly cleans up even when an exception is raised.""" + # No context should be active before run_async + assert get_current_context() is None + assert holder["value"] == 1 + + with pytest.raises(Exception, match="Some exception"): + run_async(init_and_raise(holder)) + + # After run_async (even with exception), context should be cleaned up + assert get_current_context() is None + assert holder["value"] == 
3 diff --git a/tortoise/__init__.py b/tortoise/__init__.py index b3f9ce41d..cacb04941 100644 --- a/tortoise/__init__.py +++ b/tortoise/__init__.py @@ -7,15 +7,15 @@ import warnings from collections.abc import Callable, Coroutine, Iterable from types import ModuleType -from typing import Any +from typing import TYPE_CHECKING, Any from anyio import from_thread from tortoise.apps import Apps from tortoise.backends.base.client import BaseDBAsyncClient -from tortoise.backends.base.config_generator import expand_db_url, generate_config +from tortoise.backends.base.config_generator import expand_db_url from tortoise.config import TortoiseConfig -from tortoise.connection import connections +from tortoise.connection import connections, get_connection, get_connections from tortoise.exceptions import ConfigurationError from tortoise.fields.relational import ( BackwardFKRelation, @@ -29,11 +29,90 @@ from tortoise.timezone import _reset_timezone_cache from tortoise.utils import generate_schema_for_client +if TYPE_CHECKING: + from tortoise.context import TortoiseContext + + +class classproperty: + """ + Descriptor that acts like @property but works on classes. + + This allows `Tortoise.apps` and `Tortoise._inited` to dynamically + resolve to the current context's state without using a metaclass. + + Note: This only supports getters, not setters. Internal code must + work with context directly for mutations. + + WARNING: Class-level assignment (Tortoise.apps = value) will SHADOW + this descriptor. Python's descriptor protocol only intercepts instance-level + assignment. Use TortoiseContext directly instead. + """ + + def __init__(self, func: Callable[..., Any]) -> None: + self.func = func + + def __get__(self, obj: Any, objtype: type | None = None) -> Any: + return self.func(objtype) + class Tortoise: - apps: Apps | None = None + """ + Tortoise ORM main interface. + + Provides static methods for initialization and access to ORM state. 
+ All state is managed by TortoiseContext instances. + + NOTE: No class-level state except table_name_generator for backward compat. + All runtime state lives in TortoiseContext. + """ + + # Class-level for backward compatibility; also stored in TortoiseContext table_name_generator: Callable[[type[Model]], str] | None = None - _inited: bool = False + + @classmethod + def _get_context(cls) -> TortoiseContext | None: + """Get the current context from context var.""" + from tortoise.context import get_current_context + + return get_current_context() + + @classmethod + def _require_context(cls) -> TortoiseContext: + """Get the current context, raising if none exists.""" + ctx = cls._get_context() + if ctx is None: + raise ConfigurationError( + "Tortoise ORM is not initialized. Call Tortoise.init() first " + "or use 'async with TortoiseContext()' for explicit context management." + ) + return ctx + + # BACKWARD COMPATIBLE: Class properties (no metaclass needed!) + @classproperty + def apps(cls) -> Apps | None: + """ + Get the Apps registry from current context. + + Returns None if no context is active. + """ + ctx = cls._get_context() + return ctx.apps if ctx else None + + @classproperty + def _inited(cls) -> bool: + """ + Check if Tortoise is initialized. + + Returns False if no context is active. + """ + ctx = cls._get_context() + return ctx.inited if ctx else False + + @classmethod + def is_inited(cls) -> bool: + """Check if Tortoise is initialized.""" + ctx = cls._get_context() + return ctx.inited if ctx else False @classmethod def get_connection(cls, connection_name: str) -> BaseDBAsyncClient: @@ -44,9 +123,9 @@ def get_connection(cls, connection_name: str) -> BaseDBAsyncClient: .. warning:: This is deprecated and will be removed in a future release. Please use - :meth:`connections.get` instead. + :meth:`get_connection` instead. 
""" - return connections.get(connection_name) + return get_connection(connection_name) @classmethod def describe_model( @@ -150,18 +229,29 @@ def init_app( :param model_paths: Models paths to initialize :param _init_relations: Whether to init relations or not """ - if not cls.apps: - cls.apps = Apps({}, connections, cls.table_name_generator) - cls.apps._table_name_generator = cls.table_name_generator - return cls.apps.init_app(label, model_paths, _init_relations=_init_relations) + from tortoise.context import TortoiseContext, get_current_context + + # Get or create context + ctx = get_current_context() + if ctx is None: + ctx = TortoiseContext() + ctx.__enter__() + + # Create Apps if not exists + if ctx._apps is None: + ctx._apps = Apps({}, ctx.connections, cls.table_name_generator) + ctx._apps._table_name_generator = cls.table_name_generator + return ctx._apps.init_app(label, model_paths, _init_relations=_init_relations) @classmethod def _init_apps( cls, apps_config: dict[str, dict[str, Any]], *, validate_connections: bool = True ) -> None: - cls.apps = Apps( + """Internal: Initialize Apps registry on current context.""" + ctx = cls._require_context() + ctx._apps = Apps( apps_config, - connections, + ctx.connections, cls.table_name_generator, validate_connections=validate_connections, ) @@ -201,7 +291,8 @@ async def init( routers: list[str | type] | None = None, table_name_generator: Callable[[type[Model]], str] | None = None, init_connections: bool = True, - ) -> None: + _enable_global_fallback: bool = False, + ) -> TortoiseContext: """ Sets up Tortoise-ORM: loads apps and models, configures database connections but does not connect to the database yet. The actual connection or connection pool is established @@ -269,53 +360,49 @@ async def init( :param init_connections: When ``False``, skips initializing connection clients while still loading apps and validating connection names against the config. 
+ :param _enable_global_fallback: + When ``True``, stores the context as a global fallback for cross-task access. + This is used by RegisterTortoise (FastAPI) where asgi-lifespan runs lifespan + in a background task. Default is ``False`` for pure context isolation. :raises ConfigurationError: For any configuration error + + :returns: The TortoiseContext that was initialized. For multiple asyncio.run() + calls, capture this and use 'with ctx:' to maintain context. """ - if cls._inited: - await connections.close_all(discard=True) + from tortoise.context import TortoiseContext, _current_context + + # Get or create context - only use contextvar, not global fallback. + # Global fallback is for reading (queries), not for initialization. + # This allows multiple apps to initialize independently even if one + # has global fallback enabled. + ctx = _current_context.get() + if ctx is None: + ctx = TortoiseContext() + ctx.__enter__() + elif ctx.inited: + # Re-initializing existing context + await ctx.close_connections() + + # Validate config source - must provide exactly one if int(bool(config) + bool(config_file) + bool(db_url)) != 1: raise ConfigurationError( 'You should init either from "config", "config_file" or "db_url"' ) + # Normalize config: handle config_file case + normalized_config: dict[str, Any] | TortoiseConfig | None = config if config_file: - config = cls._get_config_from_config_file(config_file) - elif db_url: - if not modules: - raise ConfigurationError('You must specify "db_url" and "modules" together') - config = generate_config(db_url, modules) - elif config is None: - raise ConfigurationError('You must specify "config" or "config_file" or "db_url"') - elif isinstance(config, TortoiseConfig): - config = config.to_dict() - else: - try: - TortoiseConfig.from_dict(config) - except ConfigurationError as exc: - warnings.warn( - f"Config validation warning: {exc}", - RuntimeWarning, - stacklevel=2, - ) - - try: - connections_config = config["connections"] - 
except KeyError: - raise ConfigurationError('Config must define "connections" section') - - try: - apps_config = config["apps"] - except KeyError: - raise ConfigurationError('Config must define "apps" section') + normalized_config = cls._get_config_from_config_file(config_file) - use_tz = config.get("use_tz", use_tz) - timezone = config.get("timezone", timezone) - routers = config.get("routers", routers) - - cls.table_name_generator = table_name_generator - - if logger.isEnabledFor(logging.DEBUG): + # Debug logging + if logger.isEnabledFor(logging.DEBUG) and normalized_config is not None: + if isinstance(normalized_config, TortoiseConfig): + config_dict = normalized_config.to_dict() + else: + config_dict = normalized_config + connections_config = config_dict.get("connections", {}) + apps_config = config_dict.get("apps", {}) str_connection_config = cls.star_password(connections_config) logger.debug( "Tortoise-ORM startup\n connections: %s\n apps: %s", @@ -323,17 +410,24 @@ async def init( str(apps_config), ) - cls._init_timezone(use_tz, timezone) - if not init_connections and _create_db: - raise ConfigurationError("init_connections=False cannot be used with _create_db=True") - if init_connections: - await connections._init(connections_config, _create_db) - else: - connections._init_config(connections_config) - cls._init_apps(apps_config, validate_connections=init_connections) - cls._init_routers(routers) + # Store table_name_generator at class level for backward compatibility + cls.table_name_generator = table_name_generator - cls._inited = True + # Delegate to context init + await ctx.init( + config=normalized_config, + db_url=db_url, + modules=modules, + _create_db=_create_db, + use_tz=use_tz, + timezone=timezone, + routers=routers, + table_name_generator=table_name_generator, + init_connections=init_connections, + _enable_global_fallback=_enable_global_fallback, + ) + + return ctx @staticmethod def star_password(connections_config) -> str: @@ -381,24 +475,22 @@ 
async def close_connections(cls) -> None: It is required for this to be called on exit, else your event loop may never complete as it is waiting for the connections to die. - - .. warning:: - This is deprecated and will be removed in a future release. Please use - :meth:`connections.close_all` instead. """ - await connections.close_all() + await get_connections().close_all() logger.info("Tortoise-ORM shutdown") @classmethod async def _reset_apps(cls) -> None: - if not cls.apps: + """Internal: Reset Apps registry on current context.""" + ctx = cls._get_context() + if ctx is None or ctx._apps is None: return - for model in cls.apps.get_models_iterable(): + for model in ctx._apps.get_models_iterable(): if isinstance(model, ModelMeta): model._meta.default_connection = None - cls.apps.clear() - cls.apps = None + ctx._apps.clear() + ctx._apps = None @classmethod async def generate_schemas(cls, safe: bool = True) -> None: @@ -413,7 +505,7 @@ async def generate_schemas(cls, safe: bool = True) -> None: """ if not cls._inited: raise ConfigurationError("You have to call .init() first before generating schemas") - for connection in connections.all(): + for connection in get_connections().all(): await generate_schema_for_client(connection, safe) @classmethod @@ -428,10 +520,11 @@ async def _drop_databases(cls) -> None: raise ConfigurationError("You have to call .init() first before deleting schemas") # this closes any existing connections/pool if any and clears # the storage - await connections.close_all(discard=False) - for conn in connections.all(): + conn_handler = get_connections() + await conn_handler.close_all(discard=False) + for conn in conn_handler.all(): await conn.db_delete() - connections.discard(conn.connection_name) + conn_handler.discard(conn.connection_name) await cls._reset_apps() @@ -461,12 +554,15 @@ async def do_stuff(): run_async(do_stuff()) """ + from tortoise.context import get_current_context async def main() -> None: try: await coro finally: - await 
connections.close_all(discard=True) + ctx = get_current_context() + if ctx is not None: + await ctx.connections.close_all(discard=True) with from_thread.start_blocking_portal() as portal: portal.call(main) @@ -474,7 +570,6 @@ async def main() -> None: __version__ = "0.25.3" - __all__ = [ "BackwardFKRelation", "BackwardOneToOneRelation", diff --git a/tortoise/apps.py b/tortoise/apps.py index de386b321..ad8ce805e 100644 --- a/tortoise/apps.py +++ b/tortoise/apps.py @@ -86,6 +86,8 @@ def init_app( return self.apps[label] def _load_from_config(self) -> None: + if self._connections is None: + raise ConfigurationError("ConnectionHandler is required to load from config") for name, info in self._config.items(): default_connection = info.get("default_connection", "default") if self._validate_connections: diff --git a/tortoise/backends/base/client.py b/tortoise/backends/base/client.py index a21d8550d..2c2304e8d 100644 --- a/tortoise/backends/base/client.py +++ b/tortoise/backends/base/client.py @@ -9,7 +9,7 @@ from tortoise.backends.base.executor import BaseExecutor from tortoise.backends.base.schema_generator import BaseSchemaGenerator -from tortoise.connection import connections +from tortoise.connection import get_connections from tortoise.exceptions import TransactionManagementError from tortoise.log import db_client_logger @@ -316,10 +316,12 @@ async def ensure_connection(self) -> None: async def __aenter__(self) -> TransactionalDBClient: await self.ensure_connection() - # Set the context variable so the current task is always seeing a - # TransactionWrapper conneciton. - self.token = connections.set(self.connection_name, self.client) + # Acquire connection first to avoid race condition where concurrent tasks + # see the wrapper via the context before it has a connection. self.client._connection = await self.client._parent._pool.acquire() + # Set the context variable so the current task is always seeing a + # TransactionWrapper connection. 
+ self.token = get_connections().set(self.connection_name, self.client) await self.client.begin() return self.client @@ -335,7 +337,7 @@ async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: finally: if self.client._parent._pool: await self.client._parent._pool.release(self.client._connection) - connections.reset(self.token) + get_connections().reset(self.token) class NestedTransactionContext(TransactionContext): diff --git a/tortoise/backends/base/schema_generator.py b/tortoise/backends/base/schema_generator.py index 993ff206c..e6d735312 100644 --- a/tortoise/backends/base/schema_generator.py +++ b/tortoise/backends/base/schema_generator.py @@ -485,11 +485,20 @@ def _get_table_sql(self, model: type[Model], safe: bool = True) -> dict: def _get_models_to_create(self) -> list[type[Model]]: from tortoise import Tortoise + from tortoise.context import get_current_context + + # Check for active TortoiseContext first + ctx = get_current_context() + if ctx is not None and ctx._inited and ctx.apps is not None: + apps = ctx.apps + else: + # Fall back to global Tortoise.apps + apps = Tortoise.apps models_to_create: list[type[Model]] = [] - if not Tortoise.apps: + if not apps: return models_to_create - for model in Tortoise.apps.get_models_iterable(): + for model in apps.get_models_iterable(): if model._meta.db == self.client: model._check() models_to_create.append(model) diff --git a/tortoise/backends/sqlite/client.py b/tortoise/backends/sqlite/client.py index 175356057..290b9e6e9 100644 --- a/tortoise/backends/sqlite/client.py +++ b/tortoise/backends/sqlite/client.py @@ -22,7 +22,7 @@ ) from tortoise.backends.sqlite.executor import SqliteExecutor from tortoise.backends.sqlite.schema_generator import SqliteSchemaGenerator -from tortoise.connection import connections +from tortoise.connection import get_connections from tortoise.contrib.sqlite.regex import ( install_regexp_functions as install_regexp_functions_to_db, ) @@ -191,7 +191,7 @@ async def 
ensure_connection(self) -> None: async def __aenter__(self) -> T_conn: await self._trxlock.acquire() await self.ensure_connection() - self.token = connections.set(self.connection_name, self.connection) + self.token = get_connections().set(self.connection_name, self.connection) await self.connection.begin() return self.connection @@ -205,7 +205,7 @@ async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: else: await self.connection.commit() finally: - connections.reset(self.token) + get_connections().reset(self.token) self._trxlock.release() diff --git a/tortoise/cli/cli.py b/tortoise/cli/cli.py index 067944b4e..92e1d357e 100644 --- a/tortoise/cli/cli.py +++ b/tortoise/cli/cli.py @@ -13,8 +13,10 @@ from ptpython.repl import embed -from tortoise import Tortoise, __version__, connections +from tortoise import Tortoise, __version__ from tortoise.cli import utils +from tortoise.connection import get_connection +from tortoise.context import TortoiseContext from tortoise.migrations.api import migrate as migrate_api from tortoise.migrations.autodetector import MigrationAutodetector from tortoise.migrations.executor import PlanStep @@ -40,12 +42,10 @@ def do_nothing(*_args, **_kwargs) -> None: @contextlib.asynccontextmanager -async def aclose_tortoise() -> AsyncGenerator[None]: - try: - yield - finally: - if Tortoise._inited: - await connections.close_all() +async def tortoise_cli_context(config: dict[str, Any]) -> AsyncGenerator[TortoiseContext, None]: + async with TortoiseContext() as ctx: + await ctx.init(config=config) + yield ctx class _NoopRecorder(MigrationRecorder): @@ -245,8 +245,7 @@ async def init(ctx: CLIContext, app_labels: tuple[str, ...]) -> None: async def shell(ctx: CLIContext) -> None: config = _normalized_config(_load_config(ctx)) - async with aclose_tortoise(): - await Tortoise.init(config=config) + async with tortoise_cli_context(config): with contextlib.suppress(EOFError, ValueError): await embed( globals=globals(), @@ -269,11 
+268,10 @@ async def makemigrations( app_config["migrations"] = migrations_module config["apps"] = apps_config - async with aclose_tortoise(): - await Tortoise.init(config=config) - if not Tortoise.apps: + async with tortoise_cli_context(config) as ctx: + if not ctx.apps: raise utils.CLIError("Tortoise apps are not initialized") - autodetector = MigrationAutodetector(Tortoise.apps, apps_config) + autodetector = MigrationAutodetector(ctx.apps, apps_config) if empty: await autodetector.loader.build_graph() old_state = await autodetector._project_state() @@ -340,7 +338,7 @@ async def _run_migrate( raise utils.CLIUsageError("MIGRATION requires APP_LABEL") target = f"{app_label}.{migration}" - async with aclose_tortoise(): + async with tortoise_cli_context(config): await migrate_api( config=config, app_labels=None, @@ -409,10 +407,9 @@ async def history(ctx: CLIContext, app_labels: tuple[str, ...]) -> None: config["apps"] = apps_config apps_by_connection = _group_apps_by_connection(apps_config) - async with aclose_tortoise(): - await Tortoise.init(config=config) + async with tortoise_cli_context(config): for connection_name, subset in apps_by_connection.items(): - recorder = MigrationRecorder(connections.get(connection_name)) + recorder = MigrationRecorder(get_connection(connection_name)) applied = await recorder.applied_migrations() _emit_history(applied, connection_name, subset) diff --git a/tortoise/connection.py b/tortoise/connection.py index ca841e7fb..a9fd5af50 100644 --- a/tortoise/connection.py +++ b/tortoise/connection.py @@ -3,8 +3,7 @@ import asyncio import contextvars import importlib -from contextvars import ContextVar -from copy import copy +from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any from tortoise.backends.base.config_generator import expand_db_url @@ -16,15 +15,67 @@ DBConfigType = dict[str, Any] +@dataclass(slots=True) +class ConnectionToken: + """ + Token for resetting connection storage modifications. 
+ + Used by transactions to temporarily replace a connection with a transaction client, + then restore the original connection when the transaction completes. + """ + + _handler: ConnectionHandler + _alias: str + _old_value: BaseDBAsyncClient | None + _cv_token: contextvars.Token | None = field(default=None) + _used: bool = field(default=False) + + class ConnectionHandler: - _conn_storage: ContextVar[dict[str, BaseDBAsyncClient]] = contextvars.ContextVar( - "_conn_storage", default={} - ) + """ + Connection management for a single TortoiseContext. + + Each TortoiseContext owns its own ConnectionHandler instance with isolated storage. + """ def __init__(self) -> None: - """Unified connection management interface.""" + """Initialize connection handler with empty storage.""" self._db_config: DBConfigType | None = None self._create_db: bool = False + # Use ContextVar for task isolation within this handler instance. + # This ensures transactions (which use .set()) are isolated to the task. + self._storage_var: contextvars.ContextVar[dict[str, BaseDBAsyncClient]] = ( + contextvars.ContextVar(f"storage_{id(self)}", default={}) + ) + + @property + def _storage(self) -> dict[str, BaseDBAsyncClient]: + """ + Internal storage for connections. + We use a property to provide a dict-like interface while being backed by a ContextVar. + """ + return self._get_storage() + + @_storage.setter + def _storage(self, value: dict[str, BaseDBAsyncClient]) -> None: + """Allow direct assignment to storage for legacy compatibility (and tests).""" + self._storage_var.set(value) + + def _get_storage(self) -> dict[str, BaseDBAsyncClient]: + """Get the connection storage dict for the current task context.""" + return self._storage_var.get() + + def _set_storage(self, new_storage: dict[str, BaseDBAsyncClient]) -> None: + """Set the connection storage dict. 
Used for testing purposes.""" + self._storage = new_storage + + def _copy_storage(self) -> dict[str, BaseDBAsyncClient]: + """Return a shallow copy of the storage.""" + return dict(self._get_storage()) + + def _clear_storage(self) -> None: + """Clear all connections from storage in the current context.""" + self._storage_var.set({}) async def _init(self, db_config: DBConfigType, create_db: bool) -> None: if self._db_config is None: @@ -61,19 +112,6 @@ def db_config(self) -> DBConfigType: ) return self._db_config - def _get_storage(self) -> dict[str, BaseDBAsyncClient]: - return self._conn_storage.get() - - def _set_storage(self, new_storage: dict[str, BaseDBAsyncClient]) -> contextvars.Token: - # Should be used only for testing purposes. - return self._conn_storage.set(new_storage) - - def _copy_storage(self) -> dict[str, BaseDBAsyncClient]: - return copy(self._get_storage()) - - def _clear_storage(self) -> None: - self._get_storage().clear() - def _discover_client_class(self, db_info: dict) -> type[BaseDBAsyncClient]: # Let exception bubble up for transparency engine_str = db_info.get("engine", "") @@ -126,7 +164,7 @@ def get(self, conn_alias: str) -> BaseDBAsyncClient: :raises ConfigurationError: If the connection alias does not exist. """ - storage: dict[str, BaseDBAsyncClient] = self._get_storage() + storage = self._get_storage() try: return storage[conn_alias] except KeyError: @@ -134,22 +172,27 @@ def get(self, conn_alias: str) -> BaseDBAsyncClient: storage[conn_alias] = connection return connection - def set(self, conn_alias: str, conn_obj: BaseDBAsyncClient) -> contextvars.Token: + def set(self, conn_alias: str, conn_obj: BaseDBAsyncClient) -> ConnectionToken: """ - Sets the given alias to the provided connection object. + Sets the given alias to the provided connection object for the current task. :param conn_alias: The alias to set the connection for. :param conn_obj: The connection object that needs to be set for this alias. 
+ :returns: A token that can be used to restore the previous context via reset(). + .. note:: - This method copies the storage from the `current context`, updates the - ``conn_alias`` with the provided ``conn_obj`` and sets the updated storage - in a `new context` and therefore returns a ``contextvars.Token`` in order to restore - the original context storage. + This method is primarily used by transactions to temporarily replace a connection + with a transaction client. Call reset() with the returned token to restore the + original connection when the transaction completes. """ + old_value = self._get_storage().get(conn_alias) storage_copy = self._copy_storage() storage_copy[conn_alias] = conn_obj - return self._conn_storage.set(storage_copy) + cv_token = self._storage_var.set(storage_copy) + return ConnectionToken( + _handler=self, _alias=conn_alias, _old_value=old_value, _cv_token=cv_token + ) def discard(self, conn_alias: str) -> BaseDBAsyncClient | None: """ @@ -163,25 +206,32 @@ def discard(self, conn_alias: str) -> BaseDBAsyncClient | None: """ return self._get_storage().pop(conn_alias, None) - def reset(self, token: contextvars.Token) -> None: + def reset(self, token: ConnectionToken | None) -> None: """ - Reset the underlying storage to the previous context state. + Reset the connection storage to the previous context state. - Resets the storage state to the `context` associated with the provided token. After - resetting storage state, any additional `connections` created in the `old context` are - copied into the `current context`. + Restores the connection state for all aliases to what it was before the set() call. :param token: - The token corresponding to the `context` to which the storage state has to - be reset. Typically, this token is obtained by calling the - :meth:`set` method of this class. + The token returned by the set() method. Can be None (no-op). 
""" - current_storage = self._get_storage() - self._conn_storage.reset(token) - prev_storage = self._get_storage() - for alias, conn in current_storage.items(): - if alias not in prev_storage: - prev_storage[alias] = conn + if token is None: + return + + if token._used: + raise ValueError("Token has already been used") + token._used = True + + if token._cv_token and isinstance(token._cv_token, contextvars.Token): + self._storage_var.reset(token._cv_token) + else: + # Fallback when no ContextVar token (e.g., mock tokens in tests) + storage = self._copy_storage() + if token._old_value is None: + storage.pop(token._alias, None) + else: + storage[token._alias] = token._old_value + self._storage = storage def all(self) -> list[BaseDBAsyncClient]: """Returns a list of connection objects from the storage in the `current context`.""" @@ -200,6 +250,9 @@ async def close_all(self, discard: bool = True) -> None: :param discard: If ``False``, all connection objects are closed but `retained` in the storage. """ + # Handle case where connections were never initialized (e.g., init failed) + if self._db_config is None: + return tasks = [conn.close() for conn in self.all()] await asyncio.gather(*tasks) if discard: @@ -207,4 +260,62 @@ async def close_all(self, discard: bool = True) -> None: self.discard(alias) -connections = ConnectionHandler() +class _ConnectionsProxy: + """ + Simple delegator that forwards all operations to the current context's ConnectionHandler. + + This provides backward compatibility for code using the `connections` module-level singleton. + All operations require an active TortoiseContext - if no context is active, a clear error is raised. + + .. deprecated:: + Direct use of `connections` is deprecated. Use `get_connection()` or `get_connections()` instead, + or access connections through the context: `ctx.connections`. 
+ """ + + def _get_handler(self) -> ConnectionHandler: + """Get the ConnectionHandler from the current context.""" + from tortoise.context import require_context + + return require_context().connections + + def __getattr__(self, name: str): + """Delegate attribute access to the current context's ConnectionHandler.""" + return getattr(self._get_handler(), name) + + # Properties must be explicit since __getattr__ doesn't intercept descriptor access + @property + def db_config(self) -> DBConfigType: + """Return the DB config.""" + return self._get_handler().db_config + + +connections = _ConnectionsProxy() + + +def get_connection(alias: str) -> BaseDBAsyncClient: + """ + Get a database connection by alias from the current context. + + This is a convenience function. Prefer accessing connections directly + via context: `ctx.connections.get(alias)` + + :param alias: The connection alias (e.g., "default") + :raises ConfigurationError: If no context is active or connection not found + """ + from tortoise.context import require_context + + return require_context().connections.get(alias) + + +def get_connections() -> ConnectionHandler: + """ + Get the ConnectionHandler from the current context. + + This is a convenience function. Prefer accessing connections directly + via context: `ctx.connections` + + :raises ConfigurationError: If no context is active + """ + from tortoise.context import require_context + + return require_context().connections diff --git a/tortoise/context.py b/tortoise/context.py new file mode 100644 index 000000000..9763cf12c --- /dev/null +++ b/tortoise/context.py @@ -0,0 +1,615 @@ +""" +Context-based state management for Tortoise ORM. + +This module provides the TortoiseContext class which encapsulates all Tortoise ORM state +(connections, apps, init status, timezone, routers) into a single context object. 
This enables: + +- Parallel test execution (each worker gets its own context) +- Event loop isolation (connections bound to context's loop) +- Clean teardown (context owns all resources) + +Usage: + async with TortoiseContext() as ctx: + await ctx.init(db_url="sqlite://:memory:", modules={"models": ["myapp.models"]}) + await ctx.generate_schemas() + # Models automatically use ctx.connections when context is active + user = await User.create(name="test") +""" + +from __future__ import annotations + +import contextvars +import importlib +import os +from collections.abc import AsyncIterator, Callable +from contextlib import asynccontextmanager +from typing import TYPE_CHECKING, Any + +from tortoise.backends.base.config_generator import generate_config +from tortoise.config import TortoiseConfig +from tortoise.connection import ConnectionHandler +from tortoise.exceptions import ConfigurationError +from tortoise.timezone import _reset_timezone_cache + +if TYPE_CHECKING: + from collections.abc import Iterable + from types import ModuleType + + from tortoise.apps import Apps + from tortoise.backends.base.client import BaseDBAsyncClient + from tortoise.models import Model + + +# ContextVar for tracking the current active context +_current_context: contextvars.ContextVar[TortoiseContext | None] = contextvars.ContextVar( + "tortoise_context", default=None +) + +# Optional global fallback context for cross-task access. +# This is used by RegisterTortoise (FastAPI) where asgi-lifespan runs lifespan +# in a background task, but requests/tests run in a different task. +# Disabled by default; enabled via Tortoise.init(_enable_global_fallback=True). +_global_context: TortoiseContext | None = None + + +def get_current_context() -> TortoiseContext | None: + """ + Get the currently active TortoiseContext, or None if no context is active. + + Checks the contextvar first (for proper isolation), then falls back to + the global context if one was set via _enable_global_fallback. 
+ + Returns: + The current TortoiseContext if one is active, None otherwise. + """ + ctx = _current_context.get() + if ctx is not None: + return ctx + return _global_context + + +def set_global_context(ctx: TortoiseContext) -> None: + """ + Set the global fallback context for cross-task access. + + This is used by RegisterTortoise (FastAPI) where asgi-lifespan runs lifespan + in a background task, but requests/tests run in a different task. + The global context allows these cross-task scenarios to work without + explicit context passing. + + Args: + ctx: The TortoiseContext to set as global fallback. + + Raises: + ConfigurationError: If a global context is already set. Only one global + context can be active at a time. For multiple isolated contexts, + use explicit TortoiseContext() without global fallback. + """ + global _global_context + if _global_context is not None: + raise ConfigurationError( + "Global context fallback is already enabled by another Tortoise.init() call. " + "Only one global context can be active at a time. " + "Use explicit TortoiseContext() for multiple isolated contexts, " + "or set _enable_global_fallback=False for secondary apps." + ) + _global_context = ctx + + +def require_context() -> TortoiseContext: + """ + Get the currently active TortoiseContext, raising if none is active. + + Returns: + The current TortoiseContext. + + Raises: + RuntimeError: If no TortoiseContext is currently active. + """ + ctx = get_current_context() + if ctx is None: + raise RuntimeError( + "No TortoiseContext is currently active. " + "Use 'async with TortoiseContext() as ctx:' to create one, " + "or call Tortoise.init() for global state." + ) + return ctx + + +class TortoiseContext: + """ + Encapsulates all Tortoise ORM state for a single execution context. 
+ + Each TortoiseContext instance owns: + - A ConnectionHandler with database connections + - An Apps registry with model definitions + - Initialization state tracking + + Use cases: + - Isolated test environments (pytest fixtures) + - Parallel test execution with pytest-xdist + - Multiple database configurations in the same process + - Scoped database sessions with automatic cleanup + + The context is tracked via contextvars, allowing async code to + automatically resolve the correct connections without explicit passing. + + Example: + async with TortoiseContext() as ctx: + await ctx.init( + db_url="sqlite://:memory:", + modules={"models": ["myapp.models"]} + ) + await ctx.generate_schemas() + # Models use this context's connections automatically + user = await User.create(name="test") + """ + + def __init__(self) -> None: + """Initialize a new TortoiseContext with empty state.""" + self._connections: ConnectionHandler | None = None + self._apps: Apps | None = None + self._inited: bool = False + self._token: contextvars.Token[TortoiseContext | None] | None = None + self._table_name_generator: Callable[[type[Model]], str] | None = None + self._default_connection: str | None = None + # Timezone settings + self._use_tz: bool = False + self._timezone: str = "UTC" + # Routers + self._routers: list[type] = [] + + @property + def connections(self) -> ConnectionHandler: + """ + Get the ConnectionHandler for this context. + + Creates a new ConnectionHandler on first access (lazy initialization). + The handler uses instance-level storage for true isolation between contexts. + + Returns: + The ConnectionHandler instance owned by this context. + """ + if self._connections is None: + # ConnectionHandler always uses instance storage for isolation + self._connections = ConnectionHandler() + return self._connections + + @property + def apps(self) -> Apps | None: + """ + Get the Apps registry for this context. + + Returns: + The Apps instance if initialized, None otherwise. 
+ """ + return self._apps + + @property + def inited(self) -> bool: + """ + Check if this context has been initialized. + + Returns: + True if init() has been called successfully, False otherwise. + """ + return self._inited + + @property + def default_connection(self) -> str | None: + """ + Get the default connection name for this context. + + Returns: + The default connection name if one is configured, None otherwise. + A default is automatically set when there's only one connection + or when a connection is named "default". + """ + return self._default_connection + + @property + def use_tz(self) -> bool: + """ + Check if timezone-aware datetimes are enabled. + + Returns: + True if datetime fields are timezone-aware, False otherwise. + """ + return self._use_tz + + @property + def timezone(self) -> str: + """ + Get the timezone configured for this context. + + Returns: + The timezone string (e.g., "UTC", "America/New_York"). + """ + return self._timezone + + @property + def routers(self) -> list[type]: + """ + Get the database routers for this context. + + Returns: + List of router classes configured for this context. 
+ """ + return self._routers + + def _get_config_from_config_file(self, config_file: str) -> dict: + """Load configuration from a JSON or YAML file.""" + import json + import os + + _, extension = os.path.splitext(config_file) + if extension in (".yml", ".yaml"): + import yaml # pylint: disable=C0415 + + with open(config_file) as f: + config = yaml.safe_load(f) + elif extension == ".json": + with open(config_file) as f: + config = json.load(f) + else: + raise ConfigurationError( + f"Unknown config extension {extension}, only .yml and .json are supported" + ) + return config + + async def init( + self, + config: dict[str, Any] | TortoiseConfig | None = None, + *, + config_file: str | None = None, + db_url: str | None = None, + modules: dict[str, Iterable[str | ModuleType]] | None = None, + _create_db: bool = False, + use_tz: bool = False, + timezone: str = "UTC", + routers: list[str | type] | None = None, + table_name_generator: Callable[[type[Model]], str] | None = None, + init_connections: bool = True, + _enable_global_fallback: bool = False, + ) -> None: + """ + Initialize this context with database configuration. + + You can configure using one of: ``config``, ``config_file``, or ``(db_url, modules)``. + + This method is self-sufficient and can be used directly in tests without + going through Tortoise.init(): + + async with TortoiseContext() as ctx: + await ctx.init(db_url="sqlite://:memory:", modules={"models": ["myapp.models"]}) + # Run tests... + + Args: + config: Full configuration dict or TortoiseConfig with 'connections' and 'apps' keys. + config_file: Path to .json or .yml file containing configuration. + db_url: Database URL string (e.g., "sqlite://:memory:"). + modules: Dictionary mapping app labels to lists of model modules. + _create_db: If True, creates the database if it doesn't exist. + use_tz: If True, datetime fields will be timezone-aware. + timezone: Timezone to use, defaults to "UTC". + routers: List of database router paths or classes. 
+ table_name_generator: Optional callable to generate table names. + init_connections: If False, skips initializing connection clients while still + loading apps and validating connection names against the config. + _enable_global_fallback: If True, sets this context as the global fallback + for cross-task access (e.g., asgi-lifespan scenarios). Default is False. + + Raises: + ConfigurationError: If configuration is invalid or incomplete. + """ + from tortoise.apps import Apps + + # Handle config_file: load it as config dict + if config_file is not None: + if config is not None: + raise ConfigurationError("Cannot specify both 'config' and 'config_file'") + config = self._get_config_from_config_file(config_file) + + # Convert input to TortoiseConfig for typed access + typed_config: TortoiseConfig + if config is None: + if db_url is None or modules is None: + raise ConfigurationError( + "Must provide either 'config', 'config_file', or both 'db_url' and 'modules'" + ) + config_dict = generate_config(db_url, app_modules=modules) + typed_config = TortoiseConfig.from_dict(config_dict) + elif isinstance(config, TortoiseConfig): + typed_config = config + else: + # Validate and convert dict config to TortoiseConfig + typed_config = TortoiseConfig.from_dict(config) + + # Convert to dict for Apps and ConnectionHandler (they expect dict format) + config_dict = typed_config.to_dict() + connections_config = config_dict["connections"] + apps_config = config_dict["apps"] + + # Use typed config values with fallback to parameters + effective_use_tz = typed_config.use_tz if typed_config.use_tz is not None else use_tz + effective_timezone = ( + typed_config.timezone if typed_config.timezone is not None else timezone + ) + effective_routers = typed_config.routers if typed_config.routers is not None else routers + + self._table_name_generator = table_name_generator + + # Validate init_connections and _create_db combination + if not init_connections and _create_db: + raise 
ConfigurationError("init_connections=False cannot be used with _create_db=True") + + # Initialize timezone + self._init_timezone(effective_use_tz, effective_timezone) + + # Initialize connections for this context + if init_connections: + await self.connections._init(connections_config, _create_db) + else: + self.connections._init_config(connections_config) + + # Initialize apps for this context + self._apps = Apps( + apps_config, + self.connections, + self._table_name_generator, + validate_connections=init_connections, + ) + + # Initialize routers + self._init_routers(effective_routers) + + # Detect default connection + connection_names = list(typed_config.connections.keys()) + if len(connection_names) == 1: + # Single connection becomes the default automatically + self._default_connection = connection_names[0] + elif "default" in connection_names: + # Connection named "default" is used as default + self._default_connection = "default" + else: + # Multiple connections without a "default" - require explicit specification + self._default_connection = None + + self._inited = True + + # Set global fallback for cross-task access if enabled + if _enable_global_fallback: + set_global_context(self) + + def _init_timezone(self, use_tz: bool, timezone: str) -> None: + """Initialize timezone settings for this context.""" + self._use_tz = use_tz + self._timezone = timezone + # Set environment variables for backward compatibility + os.environ["USE_TZ"] = str(use_tz) + os.environ["TIMEZONE"] = timezone + _reset_timezone_cache() + + def _init_routers(self, routers: list[str | type] | None = None) -> None: + """Initialize database routers for this context.""" + from tortoise.router import router + + routers = routers or [] + router_cls = [] + for r in routers: + if isinstance(r, str): + try: + module_name, class_name = r.rsplit(".", 1) + router_cls.append(getattr(importlib.import_module(module_name), class_name)) + except Exception: + raise ConfigurationError(f"Can't import router 
from `{r}`") + elif isinstance(r, type): + router_cls.append(r) + else: + raise ConfigurationError("Router must be either str or type") + self._routers = router_cls + router.init_routers(router_cls) + + async def generate_schemas(self, safe: bool = True) -> None: + """ + Generate database schemas for all models in this context. + + Args: + safe: When True, creates tables only if they don't already exist. + + Raises: + ConfigurationError: If context has not been initialized. + """ + from tortoise.utils import generate_schema_for_client + + if not self._inited: + raise ConfigurationError( + "Context not initialized. Call init() before generating schemas." + ) + for connection in self.connections.all(): + await generate_schema_for_client(connection, safe) + + def get_model(self, app_label: str, model_name: str) -> type[Model]: + """ + Retrieve a model by app label and model name. + + Args: + app_label: The app label (e.g., "models"). + model_name: The model class name (e.g., "User"). + + Returns: + The model class. + + Raises: + ConfigurationError: If context not initialized or model not found. + """ + if self._apps is None: + raise ConfigurationError( + "Context not initialized. Call init() before accessing models." + ) + return self._apps.get_model(app_label, model_name) + + def db(self, connection_name: str | None = None) -> BaseDBAsyncClient: + """ + Get a database connection by name. + + Args: + connection_name: The connection alias. If None, uses the default connection. + With a single connection, it becomes the default automatically. + With multiple connections, either specify explicitly or + configure one as "default". + + Returns: + The database client for the specified connection. + + Raises: + ConfigurationError: If context not initialized, connection not found, + or no default connection when multiple exist. + """ + if not self._inited: + raise ConfigurationError( + "Context not initialized. Call init() before accessing database." 
+ ) + + if connection_name is None: + if self._default_connection is None: + raise ConfigurationError( + "No default connection configured. Either use a single connection, " + "name one 'default', or specify connection_name explicitly." + ) + connection_name = self._default_connection + + return self.connections.get(connection_name) + + async def close_connections(self) -> None: + """ + Close all database connections owned by this context. + + This is called automatically when exiting the async context manager. + Also clears the global fallback if this context was set as global. + """ + global _global_context + if self._connections is not None: + # Only close if connections were actually initialized + if self._connections._db_config is not None: + await self._connections.close_all(discard=True) + self._connections = None + # Clear global context if this context was set as the global fallback + if _global_context is self: + _global_context = None + + def __enter__(self) -> TortoiseContext: + """ + Enter the context manager and set this context as current. + + Returns: + This context instance. + """ + self._token = _current_context.set(self) + return self + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + """ + Exit the context manager and restore the previous context. + """ + if self._token is not None: + _current_context.reset(self._token) + self._token = None + + async def __aenter__(self) -> TortoiseContext: + """ + Enter the async context manager and set this context as current. + + Returns: + This context instance. + """ + self.__enter__() + return self + + async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + """ + Exit the async context manager, close connections, and restore previous context. 
+ """ + await self.close_connections() + self._apps = None + self._inited = False + self.__exit__(exc_type, exc_val, exc_tb) + + +@asynccontextmanager +async def tortoise_test_context( + modules: list[str], + db_url: str = "sqlite://:memory:", + app_label: str = "models", + *, + connection_label: str | None = None, + use_tz: bool = False, + timezone: str = "UTC", + routers: list[str | type] | None = None, +) -> AsyncIterator[TortoiseContext]: + """ + Async context manager for isolated test database setup. + + This is the recommended way to set up Tortoise ORM for testing with pytest. + Each call creates a completely isolated context with its own: + - ConnectionHandler (no global state pollution) + - Apps registry + - Database (created fresh, cleaned up on exit) + - Timezone and router configuration + + Example with pytest-asyncio: + @pytest_asyncio.fixture + async def db(): + async with tortoise_test_context(["myapp.models"]) as ctx: + yield ctx + + @pytest.mark.asyncio + async def test_create_user(db): + user = await User.create(name="Alice") + assert user.id is not None + + Features: + - Creates isolated TortoiseContext (no global state pollution) + - Creates fresh database and generates schemas + - Cleans up connections on exit + - xdist-safe (each worker gets own context) + + Args: + modules: List of module paths to discover models from. + db_url: Database URL, defaults to in-memory SQLite. + app_label: The app label for the models, defaults to "models". + connection_label: The connection alias name. If None, defaults to "default". + use_tz: If True, datetime fields will be timezone-aware. + timezone: Timezone to use, defaults to "UTC". + routers: List of database router paths or classes. + + Yields: + An initialized TortoiseContext ready for use. 
+ """ + ctx = TortoiseContext() + async with ctx: + # Build config with explicit connection label if provided + config = generate_config( + db_url, + app_modules={app_label: modules}, + connection_label=connection_label, + testing=True, + ) + await ctx.init( + config=config, + _create_db=True, + use_tz=use_tz, + timezone=timezone, + routers=routers, + ) + await ctx.generate_schemas(safe=False) + yield ctx + + +__all__ = [ + "TortoiseContext", + "get_current_context", + "require_context", + "set_global_context", + "tortoise_test_context", +] diff --git a/tortoise/contrib/aiohttp/__init__.py b/tortoise/contrib/aiohttp/__init__.py index 0bd7424df..1af6b15ba 100644 --- a/tortoise/contrib/aiohttp/__init__.py +++ b/tortoise/contrib/aiohttp/__init__.py @@ -5,7 +5,8 @@ from aiohttp import web # pylint: disable=E0401 -from tortoise import Tortoise, connections +from tortoise import Tortoise +from tortoise.connection import get_connections from tortoise.log import logger @@ -81,13 +82,13 @@ def register_tortoise( async def init_orm(app): # pylint: disable=W0612 await Tortoise.init(config=config, config_file=config_file, db_url=db_url, modules=modules) - logger.info(f"Tortoise-ORM started, {connections._get_storage()}, {Tortoise.apps}") + logger.info(f"Tortoise-ORM started, {get_connections()._get_storage()}, {Tortoise.apps}") if generate_schemas: logger.info("Tortoise-ORM generating schema") await Tortoise.generate_schemas() async def close_orm(app): # pylint: disable=W0612 - await connections.close_all() + await Tortoise.close_connections() logger.info("Tortoise-ORM shutdown") app.on_startup.append(init_orm) diff --git a/tortoise/contrib/blacksheep/__init__.py b/tortoise/contrib/blacksheep/__init__.py index f92c9b740..015ac1907 100644 --- a/tortoise/contrib/blacksheep/__init__.py +++ b/tortoise/contrib/blacksheep/__init__.py @@ -7,7 +7,8 @@ from blacksheep.server import Application from blacksheep.server.responses import json -from tortoise import Tortoise, connections 
+from tortoise import Tortoise +from tortoise.connection import get_connections from tortoise.exceptions import DoesNotExist, IntegrityError from tortoise.log import logger @@ -89,14 +90,14 @@ def register_tortoise( @app.on_start async def init_orm(context) -> None: # pylint: disable=W0612 await Tortoise.init(config=config, config_file=config_file, db_url=db_url, modules=modules) - logger.info("Tortoise-ORM started, %s, %s", connections._get_storage(), Tortoise.apps) + logger.info("Tortoise-ORM started, %s, %s", get_connections()._get_storage(), Tortoise.apps) if generate_schemas: logger.info("Tortoise-ORM generating schema") await Tortoise.generate_schemas() @app.on_stop async def close_orm(context) -> None: # pylint: disable=W0612 - await connections.close_all() + await Tortoise.close_connections() logger.info("Tortoise-ORM shutdown") if add_exception_handlers: diff --git a/tortoise/contrib/fastapi/__init__.py b/tortoise/contrib/fastapi/__init__.py index 391fc5649..7325f1b43 100644 --- a/tortoise/contrib/fastapi/__init__.py +++ b/tortoise/contrib/fastapi/__init__.py @@ -7,7 +7,9 @@ from types import ModuleType from typing import TYPE_CHECKING -from tortoise import Tortoise, connections +from tortoise import Tortoise +from tortoise.connection import get_connections +from tortoise.context import TortoiseContext from tortoise.exceptions import DoesNotExist, IntegrityError from tortoise.log import logger @@ -102,6 +104,10 @@ class RegisterTortoise(AbstractAsyncContextManager): A boolean that specifies if datetime will be timezone-aware by default or not. timezone: Timezone to use, default is UTC. + _enable_global_fallback: + If True, enables global context fallback for cross-task access (e.g., when + using asgi-lifespan which runs lifespan in a background task). Default is True. + Set to False when running multiple apps in the same process to avoid conflicts. 
Raises ------ @@ -121,6 +127,7 @@ def __init__( use_tz: bool = False, timezone: str = "UTC", _create_db: bool = False, + _enable_global_fallback: bool = True, ) -> None: self.app = app self.config = config @@ -131,6 +138,8 @@ def __init__( self.use_tz = use_tz self.timezone = timezone self._create_db = _create_db + self._enable_global_fallback = _enable_global_fallback + self._context: TortoiseContext | None = None if add_exception_handlers and app is not None: from starlette.middleware.exceptions import ExceptionMiddleware @@ -150,8 +159,8 @@ async def wrap_middleware_call(self, *args, **kw) -> None: ExceptionMiddleware.__call__ = wrap_middleware_call # type:ignore - async def init_orm(self) -> None: # pylint: disable=W0612 - await Tortoise.init( + async def init_orm(self) -> TortoiseContext: # pylint: disable=W0612 + self._context = await Tortoise.init( config=self.config, config_file=self.config_file, db_url=self.db_url, @@ -159,15 +168,23 @@ async def init_orm(self) -> None: # pylint: disable=W0612 use_tz=self.use_tz, timezone=self.timezone, _create_db=self._create_db, + _enable_global_fallback=self._enable_global_fallback, ) - logger.info("Tortoise-ORM started, %s, %s", connections._get_storage(), Tortoise.apps) + # Store context in app.state for explicit access when global fallback is disabled + if self.app is not None: + self.app.state._tortoise_context = self._context + logger.info("Tortoise-ORM started, %s, %s", get_connections()._get_storage(), Tortoise.apps) if self.generate_schemas: logger.info("Tortoise-ORM generating schema") await Tortoise.generate_schemas() - - @staticmethod - async def close_orm() -> None: # pylint: disable=W0612 - await connections.close_all() + return self._context + + async def close_orm(self) -> None: # pylint: disable=W0612 + await Tortoise.close_connections() + # Clear context from app.state + if self.app is not None and hasattr(self.app.state, "_tortoise_context"): + delattr(self.app.state, "_tortoise_context") + 
self._context = None logger.info("Tortoise-ORM shutdown") def __call__(self, *args, **kwargs) -> Self: diff --git a/tortoise/contrib/quart/__init__.py b/tortoise/contrib/quart/__init__.py index 3a2fbf643..067fda66c 100644 --- a/tortoise/contrib/quart/__init__.py +++ b/tortoise/contrib/quart/__init__.py @@ -7,7 +7,8 @@ from quart import Quart # pylint: disable=E0401 -from tortoise import Tortoise, connections +from tortoise import Tortoise +from tortoise.connection import get_connections from tortoise.log import logger @@ -86,14 +87,14 @@ def register_tortoise( @app.before_serving async def init_orm() -> None: # pylint: disable=W0612 await Tortoise.init(config=config, config_file=config_file, db_url=db_url, modules=modules) - logger.info("Tortoise-ORM started, %s, %s", connections._get_storage(), Tortoise.apps) + logger.info("Tortoise-ORM started, %s, %s", get_connections()._get_storage(), Tortoise.apps) if _generate_schemas: logger.info("Tortoise-ORM generating schema") await Tortoise.generate_schemas() @app.after_serving async def close_orm() -> None: # pylint: disable=W0612 - await connections.close_all() + await Tortoise.close_connections() logger.info("Tortoise-ORM shutdown") @app.cli.command() # type: ignore @@ -105,7 +106,7 @@ async def inner() -> None: config=config, config_file=config_file, db_url=db_url, modules=modules ) await Tortoise.generate_schemas() - await connections.close_all() + await Tortoise.close_connections() logger.setLevel(logging.DEBUG) loop = asyncio.get_event_loop() diff --git a/tortoise/contrib/sanic/__init__.py b/tortoise/contrib/sanic/__init__.py index d2812e32e..354e40a57 100644 --- a/tortoise/contrib/sanic/__init__.py +++ b/tortoise/contrib/sanic/__init__.py @@ -5,7 +5,8 @@ from sanic import Sanic # pylint: disable=E0401 -from tortoise import Tortoise, connections +from tortoise import Tortoise +from tortoise.connection import get_connections from tortoise.log import logger @@ -81,7 +82,7 @@ def register_tortoise( async def 
tortoise_init() -> None: await Tortoise.init(config=config, config_file=config_file, db_url=db_url, modules=modules) - logger.info("Tortoise-ORM started, %s, %s", connections._get_storage(), Tortoise.apps) # pylint: disable=W0212 + logger.info("Tortoise-ORM started, %s, %s", get_connections()._get_storage(), Tortoise.apps) # pylint: disable=W0212 if generate_schemas: @@ -100,5 +101,5 @@ async def init_orm(app): @app.after_server_stop async def close_orm(app): # pylint: disable=W0612 - await connections.close_all() + await Tortoise.close_connections() logger.info("Tortoise-ORM shutdown") diff --git a/tortoise/contrib/starlette/__init__.py b/tortoise/contrib/starlette/__init__.py index ac7114b84..9cddeadfd 100644 --- a/tortoise/contrib/starlette/__init__.py +++ b/tortoise/contrib/starlette/__init__.py @@ -5,7 +5,8 @@ from starlette.applications import Starlette # pylint: disable=E0401 -from tortoise import Tortoise, connections +from tortoise import Tortoise +from tortoise.connection import get_connections from tortoise.log import logger @@ -82,12 +83,12 @@ def register_tortoise( @app.on_event("startup") async def init_orm() -> None: # pylint: disable=W0612 await Tortoise.init(config=config, config_file=config_file, db_url=db_url, modules=modules) - logger.info("Tortoise-ORM started, %s, %s", connections._get_storage(), Tortoise.apps) + logger.info("Tortoise-ORM started, %s, %s", get_connections()._get_storage(), Tortoise.apps) if generate_schemas: logger.info("Tortoise-ORM generating schema") await Tortoise.generate_schemas() @app.on_event("shutdown") async def close_orm() -> None: # pylint: disable=W0612 - await connections.close_all() + await Tortoise.close_connections() logger.info("Tortoise-ORM shutdown") diff --git a/tortoise/contrib/test/__init__.py b/tortoise/contrib/test/__init__.py index 51e0ace43..1255bf22b 100644 --- a/tortoise/contrib/test/__init__.py +++ b/tortoise/contrib/test/__init__.py @@ -1,44 +1,60 @@ +""" +Modern testing utilities for Tortoise 
ORM. + +Use tortoise_test_context() with pytest fixtures: + + @pytest_asyncio.fixture + async def db(): + async with tortoise_test_context(["myapp.models"]) as ctx: + yield ctx + + @pytest.mark.asyncio + async def test_example(db): + user = await User.create(name="Test") + assert user.id is not None + +For capability-based test skipping: + + @requireCapability(dialect="sqlite") + @pytest.mark.asyncio + async def test_sqlite_only(db): + # This test only runs on SQLite + ... +""" + from __future__ import annotations -import asyncio import inspect -import os as _os import typing -import unittest -from collections.abc import Callable, Coroutine, Iterable +from collections.abc import Callable, Coroutine from functools import partial, wraps -from types import ModuleType -from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar, cast +from typing import ParamSpec, TypeVar, cast from unittest import SkipTest, expectedFailure, skip, skipIf, skipUnless -from tortoise import Model, Tortoise, connections -from tortoise.backends.base.config_generator import generate_config as _generate_config -from tortoise.exceptions import DBConnectionError, OperationalError - -if TYPE_CHECKING: - from asyncio.events import AbstractEventLoop +from tortoise import Tortoise +from tortoise.connection import get_connection +from tortoise.context import TortoiseContext, tortoise_test_context +T = TypeVar("T") +P = ParamSpec("P") +AsyncFunc = Callable[P, Coroutine[None, None, T]] +AsyncFuncDeco = Callable[..., AsyncFunc] +ModulesConfigType = str | list[str] +MEMORY_SQLITE = "sqlite://:memory:" __all__ = ( "MEMORY_SQLITE", - "SimpleTestCase", - "TestCase", - "TruncationTestCase", - "IsolatedTestCase", - "getDBConfig", + "TortoiseContext", + "tortoise_test_context", "requireCapability", - "env_initializer", - "initializer", - "finalizer", + "truncate_all_models", + "init_memory_sqlite", "SkipTest", "expectedFailure", "skip", "skipIf", "skipUnless", - "init_memory_sqlite", ) -_TORTOISE_TEST_DB = 
"sqlite://:memory:" -# pylint: disable=W0201 expectedFailure.__doc__ = """ Mark test as expecting failure. @@ -46,51 +62,20 @@ On success it will be marked as unexpected success. """ -_CONFIG: dict = {} -_CONNECTIONS: dict = {} -_LOOP: AbstractEventLoop = None # type: ignore -_MODULES: Iterable[str | ModuleType] = [] -_CONN_CONFIG: dict = {} - -def getDBConfig(app_label: str, modules: Iterable[str | ModuleType]) -> dict: - """ - DB Config factory, for use in testing. - - :param app_label: Label of the app (must be distinct for multiple apps). - :param modules: List of modules to look for models in. +async def truncate_all_models() -> None: """ - return _generate_config( - _TORTOISE_TEST_DB, - app_modules={app_label: modules}, - testing=True, - connection_label=app_label, - ) - - -async def _init_db(config: dict) -> None: - # Placing init outside the try block since it doesn't - # establish connections to the DB eagerly. - await Tortoise.init(config) - try: - await Tortoise._drop_databases() - except (DBConnectionError, OperationalError): # pragma: nocoverage - pass + Truncate all models in the current context. - await Tortoise.init(config, _create_db=True) - await Tortoise.generate_schemas(safe=False) + This is a utility function for test cleanup that deletes all rows from + all registered model tables. + Note: This is a naive implementation that may fail with M2M relations + and non-cascade foreign keys. -def _restore_default() -> None: - Tortoise.apps = None - connections._get_storage().update(_CONNECTIONS.copy()) - connections._db_config = _CONN_CONFIG.copy() - Tortoise._init_apps(_CONFIG["apps"]) - Tortoise._inited = True - - -async def truncate_all_models() -> None: - # TODO: This is a naive implementation: Will fail to clear M2M and non-cascade foreign keys + Raises: + ValueError: If Tortoise.apps is not loaded. 
+ """ if not Tortoise.apps: raise ValueError("apps are not loaded") for model in Tortoise.apps.get_models_iterable(): @@ -100,236 +85,12 @@ async def truncate_all_models() -> None: ) -def initializer( - modules: Iterable[str | ModuleType], - db_url: str | None = None, - app_label: str = "models", - loop: AbstractEventLoop | None = None, -) -> None: - """ - Sets up the DB for testing. Must be called as part of test environment setup. - - :param modules: List of modules to look for models in. - :param db_url: The db_url, defaults to ``sqlite://:memory``. - :param app_label: The name of the APP to initialise the modules in, defaults to "models" - :param loop: Optional event loop. - """ - # pylint: disable=W0603 - global _CONFIG - global _CONNECTIONS - global _LOOP - global _TORTOISE_TEST_DB - global _MODULES - global _CONN_CONFIG - _MODULES = modules - if db_url is not None: # pragma: nobranch - _TORTOISE_TEST_DB = db_url - _CONFIG = getDBConfig(app_label=app_label, modules=_MODULES) - if not loop: - try: - loop = asyncio.get_running_loop() - except RuntimeError: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - _LOOP = loop - loop.run_until_complete(_init_db(_CONFIG)) - _CONNECTIONS = connections._copy_storage() - _CONN_CONFIG = connections.db_config.copy() - connections._clear_storage() - connections.db_config.clear() - Tortoise.apps = None - Tortoise._inited = False - - -def finalizer() -> None: - """ - Cleans up the DB after testing. Must be called as part of the test environment teardown. - """ - _restore_default() - loop = _LOOP - loop.run_until_complete(Tortoise._drop_databases()) - - -def env_initializer() -> None: # pragma: nocoverage - """ - Calls ``initializer()`` with parameters mapped from environment variables. - - ``TORTOISE_TEST_MODULES``: - A comma-separated list of modules to include *(required)* - ``TORTOISE_TEST_APP``: - The name of the APP to initialise the modules in *(optional)* - - If not provided, it will default to "models". 
- ``TORTOISE_TEST_DB``: - The db_url of the test db. *(optional*) - - If not provided, it will default to an in-memory SQLite DB. - """ - modules = str(_os.environ.get("TORTOISE_TEST_MODULES", "tests.testmodels")).split(",") - db_url = _os.environ.get("TORTOISE_TEST_DB", "sqlite://:memory:") - app_label = _os.environ.get("TORTOISE_TEST_APP", "models") - if not modules: # pragma: nocoverage - raise Exception("TORTOISE_TEST_MODULES envvar not defined") - initializer(modules, db_url=db_url, app_label=app_label) - - -class SimpleTestCase(unittest.IsolatedAsyncioTestCase): - """ - The Tortoise base test class. - - This will ensure that your DB environment has a test double set up for use. - - An asyncio capable test class that provides some helper functions. - - Will run any ``test_*()`` function either as sync or async, depending - on the signature of the function. - If you specify ``async test_*()`` then it will run it in an event loop. +_FT = TypeVar("_FT", bound=Callable[..., typing.Any]) - Based on `asynctest `_ - """ - - def _setupAsyncioRunner(self) -> None: - if hasattr(asyncio, "Runner"): # For python3.11+ - runner = asyncio.Runner(debug=True, loop_factory=asyncio.get_event_loop) - self._asyncioRunner = runner - - def _tearDownAsyncioRunner(self) -> None: - # Override runner tear down to avoid eventloop closing before testing completed. 
- pass - - async def _setUpDB(self) -> None: - pass - - async def _tearDownDB(self) -> None: - pass - - def _setupAsyncioLoop(self): - loop = asyncio.get_event_loop() - loop.set_debug(True) - self._asyncioTestLoop = loop - fut = loop.create_future() - self._asyncioCallsTask = loop.create_task(self._asyncioLoopRunner(fut)) # type: ignore - loop.run_until_complete(fut) - - def _tearDownAsyncioLoop(self): - loop = self._asyncioTestLoop - self._asyncioTestLoop = None # type: ignore - self._asyncioCallsQueue.put_nowait(None) # type: ignore - loop.run_until_complete(self._asyncioCallsQueue.join()) # type: ignore - - async def asyncSetUp(self) -> None: - self._reset_conn_state() - Tortoise.apps = None - Tortoise._inited = False - await self._setUpDB() - - def _reset_conn_state(self) -> None: - # clearing the storage and db config - connections._clear_storage() - connections.db_config.clear() - - async def asyncTearDown(self) -> None: - await self._tearDownDB() - self._reset_conn_state() - Tortoise.apps = None - Tortoise._inited = False - - def assertListSortEqual( - self, list1: list[Any], list2: list[Any], msg: Any = ..., sorted_key: str | None = None - ) -> None: - if isinstance(list1[0], Model): - super().assertListEqual( - sorted(list1, key=lambda x: x.pk), sorted(list2, key=lambda x: x.pk), msg=msg - ) - elif isinstance(list1[0], dict) and sorted_key: - super().assertListEqual( - sorted(list1, key=lambda x: x[sorted_key]), - sorted(list2, key=lambda x: x[sorted_key]), - msg=msg, - ) - else: - super().assertListEqual(sorted(list1), sorted(list2), msg=msg) - - -class IsolatedTestCase(SimpleTestCase): - """ - An asyncio capable test class that will ensure that an isolated test db - is available for each test. - - Use this if your test needs perfect isolation. - - Note to use ``{}`` as a string-replacement parameter, for your DB_URL. - That will create a randomised database name. - It will create and destroy a new DB instance for every test. 
- This is obviously slow, but guarantees a fresh DB. - - If you define a ``tortoise_test_modules`` list, it overrides the DB setup module for the tests. - """ - - tortoise_test_modules: Iterable[str | ModuleType] = [] - - async def _setUpDB(self) -> None: - await super()._setUpDB() - config = getDBConfig(app_label="models", modules=self.tortoise_test_modules or _MODULES) - await Tortoise.init(config, _create_db=True) - await Tortoise.generate_schemas(safe=False) - - async def _tearDownDB(self) -> None: - await Tortoise._drop_databases() - - -class TruncationTestCase(SimpleTestCase): - """ - An asyncio capable test class that will truncate the tables after a test. - - Use this when your tests contain transactions. - - This is slower than ``TestCase`` but faster than ``IsolatedTestCase``. - Note that usage of this does not guarantee that auto-number-pks will be reset to 1. - """ - - async def _setUpDB(self) -> None: - await super()._setUpDB() - _restore_default() - - async def _tearDownDB(self) -> None: - _restore_default() - await truncate_all_models() - await super()._tearDownDB() - - -class _RollbackException(Exception): - pass - - -class TestCase(TruncationTestCase): - """ - An asyncio capable test class that will ensure that each test will be run at - separate transaction that will rollback on finish. - - This is a fast test runner. Don't use it if your test uses transactions. 
- """ - - async def asyncSetUp(self) -> None: - await super().asyncSetUp() - self._db = connections.get("models") - self._transaction = self._db._in_transaction() - await self._transaction.__aenter__() - - async def asyncTearDown(self) -> None: - # this will cause a rollback - await self._transaction.__aexit__(_RollbackException, _RollbackException(), None) - await super().asyncTearDown() - - async def _tearDownDB(self) -> None: - if self._db.capabilities.supports_transactions: - _restore_default() - else: - await super()._tearDownDB() - - -def requireCapability(connection_name: str = "models", **conditions: Any) -> Callable: +def requireCapability( + connection_name: str = "models", **conditions: typing.Any +) -> Callable[[_FT], _FT]: """ Skip a test if the required capabilities are not matched. @@ -341,7 +102,8 @@ def requireCapability(connection_name: str = "models", **conditions: Any) -> Cal .. code-block:: python3 @requireCapability(dialect='sqlite') - async def test_run_sqlite_only(self): + @pytest.mark.asyncio + async def test_run_sqlite_only(db): ... Or to conditionally skip a class: @@ -349,38 +111,39 @@ async def test_run_sqlite_only(self): .. code-block:: python3 @requireCapability(dialect='sqlite') - class TestSqlite(test.TestCase): - ... + class TestSqlite: + @pytest.mark.asyncio + async def test_something(self, db): + ... :param connection_name: name of the connection to retrieve capabilities from. :param conditions: capability tests which must all pass for the test to run. 
""" - def decorator(test_item): + def decorator(test_item: _FT) -> _FT: if not isinstance(test_item, type): def check_capabilities() -> None: - db = connections.get(connection_name) + db = get_connection(connection_name) for key, val in conditions.items(): if getattr(db.capabilities, key) != val: raise SkipTest(f"Capability {key} != {val}") - if hasattr(asyncio, "Runner") and inspect.iscoroutinefunction(test_item): - # For python3.11+ + if inspect.iscoroutinefunction(test_item): @wraps(test_item) - async def skip_wrapper(*args, **kwargs): + async def skip_wrapper(*args: typing.Any, **kwargs: typing.Any) -> typing.Any: check_capabilities() return await test_item(*args, **kwargs) else: @wraps(test_item) - def skip_wrapper(*args, **kwargs): + def skip_wrapper(*args: typing.Any, **kwargs: typing.Any) -> typing.Any: check_capabilities() return test_item(*args, **kwargs) - return skip_wrapper + return cast(_FT, skip_wrapper) # Assume a class is decorated funcs = { @@ -400,14 +163,6 @@ def skip_wrapper(*args, **kwargs): return decorator -T = TypeVar("T") -P = ParamSpec("P") -AsyncFunc = Callable[P, Coroutine[None, None, T]] -AsyncFuncDeco = Callable[..., AsyncFunc] -ModulesConfigType = str | list[str] -MEMORY_SQLITE = "sqlite://:memory:" - - @typing.overload def init_memory_sqlite(models: ModulesConfigType | None = None) -> AsyncFuncDeco: ... @@ -420,9 +175,11 @@ def init_memory_sqlite( models: ModulesConfigType | AsyncFunc | None = None, ) -> AsyncFunc | AsyncFuncDeco: """ - For single file style to run code with memory sqlite + Decorator for initializing Tortoise with an in-memory SQLite database. - :param models: list_of_modules that should be discovered for models, default to ['__main__']. + This is useful for simple scripts and examples that need a quick database setup. + + :param models: List of modules to load models from. Defaults to ["__main__"]. 
Usage: @@ -440,9 +197,8 @@ async def run(): obj = await MyModel.create(name='') assert obj.id == 1 - if __name__ == '__main__' - run_async(run) - + if __name__ == '__main__': + run_async(run()) Custom models example: diff --git a/tortoise/fields/data.py b/tortoise/fields/data.py index 3c456e9c8..bcfe73b13 100644 --- a/tortoise/fields/data.py +++ b/tortoise/fields/data.py @@ -232,7 +232,15 @@ def __init__( raise ConfigurationError( "TextField doesn't support unique indexes, consider CharField or another strategy" ) - if db_index or kwargs.get("index"): + if (index := kwargs.pop("index", None)) is not None: + warnings.warn( + "`index` is deprecated, please use `db_index` instead", + DeprecationWarning, + stacklevel=2, + ) + if index or db_index: + raise ConfigurationError("TextField can't be indexed, consider CharField") + elif db_index: raise ConfigurationError("TextField can't be indexed, consider CharField") super().__init__(primary_key=primary_key, **kwargs) diff --git a/tortoise/migrations/api/migrate.py b/tortoise/migrations/api/migrate.py index 641717bae..b1b579490 100644 --- a/tortoise/migrations/api/migrate.py +++ b/tortoise/migrations/api/migrate.py @@ -4,8 +4,9 @@ from collections.abc import Callable, Sequence from typing import Any -from tortoise import Tortoise, connections +from tortoise import Tortoise from tortoise.config import TortoiseConfig +from tortoise.connection import get_connection from tortoise.migrations.executor import MigrationExecutor, MigrationTarget, PlanStep @@ -44,7 +45,7 @@ async def migrate( targets = _parse_targets(target, selected_apps) for connection_name, subset in apps_by_connection.items(): - connection = connections.get(connection_name) + connection = get_connection(connection_name) executor = MigrationExecutor(connection, subset) executor_targets = [t for t in targets if t.app_label in subset] if reporter is not None: diff --git a/tortoise/migrations/api/plan.py b/tortoise/migrations/api/plan.py index 736ee11fe..2dce67f56 
100644 --- a/tortoise/migrations/api/plan.py +++ b/tortoise/migrations/api/plan.py @@ -3,8 +3,9 @@ from collections.abc import Sequence from typing import Any -from tortoise import Tortoise, connections +from tortoise import Tortoise from tortoise.config import TortoiseConfig +from tortoise.connection import get_connection from tortoise.migrations.executor import MigrationExecutor, MigrationTarget, PlanStep @@ -42,7 +43,7 @@ async def plan( targets = _parse_targets(target, selected_apps) output: list[str] = [] for connection_name, subset in apps_by_connection.items(): - connection = connections.get(connection_name) + connection = get_connection(connection_name) executor = MigrationExecutor(connection, subset) executor_targets = [t for t in targets if t.app_label in subset] steps = await executor.plan(executor_targets if executor_targets else None) diff --git a/tortoise/migrations/schema_generator/state_apps.py b/tortoise/migrations/schema_generator/state_apps.py index 7d976063c..2f88b0455 100644 --- a/tortoise/migrations/schema_generator/state_apps.py +++ b/tortoise/migrations/schema_generator/state_apps.py @@ -5,12 +5,21 @@ from pypika_tortoise import Query, Table from tortoise.apps import Apps -from tortoise.connection import connections +from tortoise.connection import ConnectionHandler +from tortoise.context import get_current_context from tortoise.models import Model class StateApps(Apps): - def __init__(self, default_connections: dict[str, str] | None = None) -> None: + def __init__( + self, + default_connections: dict[str, str] | None = None, + connections: ConnectionHandler | None = None, + ) -> None: + if connections is None: + ctx = get_current_context() + connections = ctx.connections if ctx is not None else ConnectionHandler() + super().__init__({}, connections) self._default_connections = default_connections or {} @@ -27,6 +36,11 @@ def register_model(self, app_label: str, model: type[Model]) -> None: model._meta.default_connection = 
self._default_connections[app_label] def _build_initial_querysets(self) -> None: + # Skip building querysets when no DB config is available (state-only mode) + # This allows pure state operations to work without database connections + if self._connections._db_config is None: + return + for app in self.apps.values(): for model in app.values(): if model._meta.default_connection is None: @@ -67,7 +81,10 @@ def get_model(self, app_label: str, model_name: str | None = None) -> type[Model def clone(self) -> StateApps: from tortoise.migrations.schema_generator.state import ModelState - state_apps = self.__class__(default_connections=dict(self._default_connections)) + state_apps = self.__class__( + default_connections=dict(self._default_connections), + connections=self._connections, + ) for app_label, app in self.apps.items(): for model in app.values(): model_clone = ModelState.make_from_model(app_label, model).render(state_apps) diff --git a/tortoise/models.py b/tortoise/models.py index c536a4fda..12f1ecdd3 100644 --- a/tortoise/models.py +++ b/tortoise/models.py @@ -13,7 +13,7 @@ from pypika_tortoise.terms import Term from tortoise.backends.base.client import BaseDBAsyncClient -from tortoise.connection import connections +from tortoise.connection import get_connection from tortoise.exceptions import ( ConfigurationError, DoesNotExist, @@ -286,7 +286,7 @@ def db(self) -> BaseDBAsyncClient: raise ConfigurationError( f"default_connection for the model {self._model} cannot be None" ) - return connections.get(self.default_connection) + return get_connection(self.default_connection) @property def ordering(self) -> tuple[tuple[str, Order], ...]: diff --git a/tortoise/query_api.py b/tortoise/query_api.py index 99b4acf25..a4b2a95eb 100644 --- a/tortoise/query_api.py +++ b/tortoise/query_api.py @@ -7,7 +7,7 @@ from pypika_tortoise.terms import Parameterizer from tortoise.backends.base.client import BaseDBAsyncClient -from tortoise.connection import connections +from 
tortoise.connection import get_connections from tortoise.exceptions import ParamsError if TYPE_CHECKING: @@ -87,14 +87,16 @@ async def execute_pypika( ) -> QueryResult[SchemaT] | QueryResult[dict]: if using_db is not None: db = using_db - elif len(connections.db_config) == 1: - connection_name = next(iter(connections.db_config.keys())) - db = connections.get(connection_name) else: - raise ParamsError( - "You are running with multiple databases, so you should specify" - f" connection_name: {list(connections.db_config)}" - ) + conn_handler = get_connections() + if len(conn_handler.db_config) == 1: + connection_name = next(iter(conn_handler.db_config.keys())) + db = conn_handler.get(connection_name) + else: + raise ParamsError( + "You are running with multiple databases, so you should specify" + f" connection_name: {list(conn_handler.db_config)}" + ) sql, params = query.get_parameterized_sql(_get_sql_context(db)) rows, rows_affected = await db.execute_query_dict_with_affected(sql, params) diff --git a/tortoise/queryset.py b/tortoise/queryset.py index bb7eb69bc..3d32e63b0 100644 --- a/tortoise/queryset.py +++ b/tortoise/queryset.py @@ -84,7 +84,7 @@ class AwaitableQuery(Generic[MODEL]): "model", "_joined_tables", "_db", - "capabilities", + "_capabilities", "_annotations", "_custom_filters", "_q_objects", @@ -95,11 +95,21 @@ def __init__(self, model: type[MODEL]) -> None: self.model: type[MODEL] = model self.query: QueryBuilder = QUERY self._db: BaseDBAsyncClient = None # type: ignore - self.capabilities: Capabilities = model._meta.db.capabilities + self._capabilities: Capabilities | None = None self._annotations: dict[str, Expression | Term] = {} self._custom_filters: dict[str, FilterInfoDict] = {} self._q_objects: list[Q] = [] + @property + def capabilities(self) -> Capabilities: + if self._capabilities is None: + self._capabilities = self.model._meta.db.capabilities + return self._capabilities + + @capabilities.setter + def capabilities(self, value: Capabilities) -> 
None: + self._capabilities = value + def _choose_db(self, for_write: bool = False) -> BaseDBAsyncClient: """ Return the connection that will be used if this query is executed now. diff --git a/tortoise/router.py b/tortoise/router.py index 612db31c9..4581055a7 100644 --- a/tortoise/router.py +++ b/tortoise/router.py @@ -3,7 +3,7 @@ from collections.abc import Callable from typing import TYPE_CHECKING, Any -from tortoise.connection import connections +from tortoise.connection import get_connection from tortoise.exceptions import ConfigurationError if TYPE_CHECKING: @@ -31,7 +31,7 @@ def _router_func(self, model: type[Model], action: str) -> Any: def _db_route(self, model: type[Model], action: str) -> BaseDBAsyncClient | None: try: - return connections.get(self._router_func(model, action)) + return get_connection(self._router_func(model, action)) except ConfigurationError: return None diff --git a/tortoise/transactions.py b/tortoise/transactions.py index cf2cc0bc3..d92b32db4 100644 --- a/tortoise/transactions.py +++ b/tortoise/transactions.py @@ -4,7 +4,7 @@ from functools import wraps from typing import TYPE_CHECKING, TypeVar, cast -from tortoise.connection import connections +from tortoise.connection import get_connections from tortoise.exceptions import ParamsError if TYPE_CHECKING: # pragma: nocoverage @@ -16,15 +16,16 @@ def _get_connection(connection_name: str | None) -> BaseDBAsyncClient: + conn_handler = get_connections() if connection_name: - connection = connections.get(connection_name) - elif len(connections.db_config) == 1: - connection_name = next(iter(connections.db_config.keys())) - connection = connections.get(connection_name) + connection = conn_handler.get(connection_name) + elif len(conn_handler.db_config) == 1: + connection_name = next(iter(conn_handler.db_config.keys())) + connection = conn_handler.get(connection_name) else: raise ParamsError( "You are running with multiple databases, so you should specify" - f" connection_name: 
{list(connections.db_config)}" + f" connection_name: {list(conn_handler.db_config)}" ) return connection diff --git a/uv.lock b/uv.lock index faa25cb43..b48a489c3 100644 --- a/uv.lock +++ b/uv.lock @@ -384,11 +384,11 @@ wheels = [ [[package]] name = "babel" -version = "2.17.0" +version = "2.18.0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/7d/6b/d52e42361e1aa00709585ecc30b3f9684b3ab62530771402248b1b1d6240/babel-2.17.0.tar.gz", hash = "sha256:0c54cffb19f690cdcc52a3b50bcbf71e07a808d1c80d549f2459b9d2cf0afb9d", size = 9951852, upload-time = "2025-02-01T15:17:41.026Z" } +sdist = { url = "https://files.pythonhosted.org/packages/7d/b2/51899539b6ceeeb420d40ed3cd4b7a40519404f9baf3d4ac99dc413a834b/babel-2.18.0.tar.gz", hash = "sha256:b80b99a14bd085fcacfa15c9165f651fbb3406e66cc603abf11c5750937c992d", size = 9959554, upload-time = "2026-02-01T12:30:56.078Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/b7/b8/3fe70c75fe32afc4bb507f75563d39bc5642255d1d94f1f23604725780bf/babel-2.17.0-py3-none-any.whl", hash = "sha256:4d0b53093fdfb4b21c92b5213dba5a1b23885afa8383709427046b21c366e5f2", size = 10182537, upload-time = "2025-02-01T15:17:37.39Z" }, + { url = "https://files.pythonhosted.org/packages/77/f5/21d2de20e8b8b0408f0681956ca2c69f1320a3848ac50e6e7f39c6159675/babel-2.18.0-py3-none-any.whl", hash = "sha256:e2b422b277c2b9a9630c1d7903c2a00d0830c409c59ac8cae9081c92f1aeba35", size = 10196845, upload-time = "2026-02-01T12:30:53.445Z" }, ] [[package]] @@ -860,37 +860,54 @@ dependencies = [ ] sdist = { url = "https://files.pythonhosted.org/packages/78/19/f748958276519adf6a0c1e79e7b8860b4830dda55ccdf29f2719b5fc499c/cryptography-46.0.4.tar.gz", hash = "sha256:bfd019f60f8abc2ed1b9be4ddc21cfef059c841d86d710bb69909a688cbb8f59", size = 749301, upload-time = "2026-01-28T00:24:37.379Z" } wheels = [ + { url = 
"https://files.pythonhosted.org/packages/8d/99/157aae7949a5f30d51fcb1a9851e8ebd5c74bf99b5285d8bb4b8b9ee641e/cryptography-46.0.4-cp311-abi3-macosx_10_9_universal2.whl", hash = "sha256:281526e865ed4166009e235afadf3a4c4cba6056f99336a99efba65336fd5485", size = 7173686, upload-time = "2026-01-28T00:23:07.515Z" }, { url = "https://files.pythonhosted.org/packages/87/91/874b8910903159043b5c6a123b7e79c4559ddd1896e38967567942635778/cryptography-46.0.4-cp311-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:5f14fba5bf6f4390d7ff8f086c566454bff0411f6d8aa7af79c88b6f9267aecc", size = 4275871, upload-time = "2026-01-28T00:23:09.439Z" }, { url = "https://files.pythonhosted.org/packages/c0/35/690e809be77896111f5b195ede56e4b4ed0435b428c2f2b6d35046fbb5e8/cryptography-46.0.4-cp311-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:47bcd19517e6389132f76e2d5303ded6cf3f78903da2158a671be8de024f4cd0", size = 4423124, upload-time = "2026-01-28T00:23:11.529Z" }, { url = "https://files.pythonhosted.org/packages/1a/5b/a26407d4f79d61ca4bebaa9213feafdd8806dc69d3d290ce24996d3cfe43/cryptography-46.0.4-cp311-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:01df4f50f314fbe7009f54046e908d1754f19d0c6d3070df1e6268c5a4af09fa", size = 4277090, upload-time = "2026-01-28T00:23:13.123Z" }, + { url = "https://files.pythonhosted.org/packages/0c/d8/4bb7aec442a9049827aa34cee1aa83803e528fa55da9a9d45d01d1bb933e/cryptography-46.0.4-cp311-abi3-manylinux_2_28_ppc64le.whl", hash = "sha256:5aa3e463596b0087b3da0dbe2b2487e9fc261d25da85754e30e3b40637d61f81", size = 4947652, upload-time = "2026-01-28T00:23:14.554Z" }, { url = "https://files.pythonhosted.org/packages/2b/08/f83e2e0814248b844265802d081f2fac2f1cbe6cd258e72ba14ff006823a/cryptography-46.0.4-cp311-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:0a9ad24359fee86f131836a9ac3bffc9329e956624a2d379b613f8f8abaf5255", size = 4455157, upload-time = "2026-01-28T00:23:16.443Z" }, { url = 
"https://files.pythonhosted.org/packages/0a/05/19d849cf4096448779d2dcc9bb27d097457dac36f7273ffa875a93b5884c/cryptography-46.0.4-cp311-abi3-manylinux_2_31_armv7l.whl", hash = "sha256:dc1272e25ef673efe72f2096e92ae39dea1a1a450dd44918b15351f72c5a168e", size = 3981078, upload-time = "2026-01-28T00:23:17.838Z" }, { url = "https://files.pythonhosted.org/packages/e6/89/f7bac81d66ba7cde867a743ea5b37537b32b5c633c473002b26a226f703f/cryptography-46.0.4-cp311-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:de0f5f4ec8711ebc555f54735d4c673fc34b65c44283895f1a08c2b49d2fd99c", size = 4276213, upload-time = "2026-01-28T00:23:19.257Z" }, + { url = "https://files.pythonhosted.org/packages/da/9f/7133e41f24edd827020ad21b068736e792bc68eecf66d93c924ad4719fb3/cryptography-46.0.4-cp311-abi3-manylinux_2_34_ppc64le.whl", hash = "sha256:eeeb2e33d8dbcccc34d64651f00a98cb41b2dc69cef866771a5717e6734dfa32", size = 4912190, upload-time = "2026-01-28T00:23:21.244Z" }, { url = "https://files.pythonhosted.org/packages/a6/f7/6d43cbaddf6f65b24816e4af187d211f0bc536a29961f69faedc48501d8e/cryptography-46.0.4-cp311-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:3d425eacbc9aceafd2cb429e42f4e5d5633c6f873f5e567077043ef1b9bbf616", size = 4454641, upload-time = "2026-01-28T00:23:22.866Z" }, { url = "https://files.pythonhosted.org/packages/9e/4f/ebd0473ad656a0ac912a16bd07db0f5d85184924e14fc88feecae2492834/cryptography-46.0.4-cp311-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:91627ebf691d1ea3976a031b61fb7bac1ccd745afa03602275dda443e11c8de0", size = 4405159, upload-time = "2026-01-28T00:23:25.278Z" }, { url = "https://files.pythonhosted.org/packages/d1/f7/7923886f32dc47e27adeff8246e976d77258fd2aa3efdd1754e4e323bf49/cryptography-46.0.4-cp311-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:2d08bc22efd73e8854b0b7caff402d735b354862f1145d7be3b9c0f740fef6a0", size = 4666059, upload-time = "2026-01-28T00:23:26.766Z" }, + { url = 
"https://files.pythonhosted.org/packages/eb/a7/0fca0fd3591dffc297278a61813d7f661a14243dd60f499a7a5b48acb52a/cryptography-46.0.4-cp311-abi3-win32.whl", hash = "sha256:82a62483daf20b8134f6e92898da70d04d0ef9a75829d732ea1018678185f4f5", size = 3026378, upload-time = "2026-01-28T00:23:28.317Z" }, + { url = "https://files.pythonhosted.org/packages/2d/12/652c84b6f9873f0909374864a57b003686c642ea48c84d6c7e2c515e6da5/cryptography-46.0.4-cp311-abi3-win_amd64.whl", hash = "sha256:6225d3ebe26a55dbc8ead5ad1265c0403552a63336499564675b29eb3184c09b", size = 3478614, upload-time = "2026-01-28T00:23:30.275Z" }, + { url = "https://files.pythonhosted.org/packages/b9/27/542b029f293a5cce59349d799d4d8484b3b1654a7b9a0585c266e974a488/cryptography-46.0.4-cp314-cp314t-macosx_10_9_universal2.whl", hash = "sha256:485e2b65d25ec0d901bca7bcae0f53b00133bf3173916d8e421f6fddde103908", size = 7116417, upload-time = "2026-01-28T00:23:31.958Z" }, { url = "https://files.pythonhosted.org/packages/f8/f5/559c25b77f40b6bf828eabaf988efb8b0e17b573545edb503368ca0a2a03/cryptography-46.0.4-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:078e5f06bd2fa5aea5a324f2a09f914b1484f1d0c2a4d6a8a28c74e72f65f2da", size = 4264508, upload-time = "2026-01-28T00:23:34.264Z" }, { url = "https://files.pythonhosted.org/packages/49/a1/551fa162d33074b660dc35c9bc3616fefa21a0e8c1edd27b92559902e408/cryptography-46.0.4-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:dce1e4f068f03008da7fa51cc7abc6ddc5e5de3e3d1550334eaf8393982a5829", size = 4409080, upload-time = "2026-01-28T00:23:35.793Z" }, { url = "https://files.pythonhosted.org/packages/b0/6a/4d8d129a755f5d6df1bbee69ea2f35ebfa954fa1847690d1db2e8bca46a5/cryptography-46.0.4-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:2067461c80271f422ee7bdbe79b9b4be54a5162e90345f86a23445a0cf3fd8a2", size = 4270039, upload-time = "2026-01-28T00:23:37.263Z" }, + { url = 
"https://files.pythonhosted.org/packages/4c/f5/ed3fcddd0a5e39321e595e144615399e47e7c153a1fb8c4862aec3151ff9/cryptography-46.0.4-cp314-cp314t-manylinux_2_28_ppc64le.whl", hash = "sha256:c92010b58a51196a5f41c3795190203ac52edfd5dc3ff99149b4659eba9d2085", size = 4926748, upload-time = "2026-01-28T00:23:38.884Z" }, { url = "https://files.pythonhosted.org/packages/43/ae/9f03d5f0c0c00e85ecb34f06d3b79599f20630e4db91b8a6e56e8f83d410/cryptography-46.0.4-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:829c2b12bbc5428ab02d6b7f7e9bbfd53e33efd6672d21341f2177470171ad8b", size = 4442307, upload-time = "2026-01-28T00:23:40.56Z" }, { url = "https://files.pythonhosted.org/packages/8b/22/e0f9f2dae8040695103369cf2283ef9ac8abe4d51f68710bec2afd232609/cryptography-46.0.4-cp314-cp314t-manylinux_2_31_armv7l.whl", hash = "sha256:62217ba44bf81b30abaeda1488686a04a702a261e26f87db51ff61d9d3510abd", size = 3959253, upload-time = "2026-01-28T00:23:42.827Z" }, { url = "https://files.pythonhosted.org/packages/01/5b/6a43fcccc51dae4d101ac7d378a8724d1ba3de628a24e11bf2f4f43cba4d/cryptography-46.0.4-cp314-cp314t-manylinux_2_34_aarch64.whl", hash = "sha256:9c2da296c8d3415b93e6053f5a728649a87a48ce084a9aaf51d6e46c87c7f2d2", size = 4269372, upload-time = "2026-01-28T00:23:44.655Z" }, + { url = "https://files.pythonhosted.org/packages/17/b7/0f6b8c1dd0779df2b526e78978ff00462355e31c0a6f6cff8a3e99889c90/cryptography-46.0.4-cp314-cp314t-manylinux_2_34_ppc64le.whl", hash = "sha256:9b34d8ba84454641a6bf4d6762d15847ecbd85c1316c0a7984e6e4e9f748ec2e", size = 4891908, upload-time = "2026-01-28T00:23:46.48Z" }, { url = "https://files.pythonhosted.org/packages/83/17/259409b8349aa10535358807a472c6a695cf84f106022268d31cea2b6c97/cryptography-46.0.4-cp314-cp314t-manylinux_2_34_x86_64.whl", hash = "sha256:df4a817fa7138dd0c96c8c8c20f04b8aaa1fac3bbf610913dcad8ea82e1bfd3f", size = 4441254, upload-time = "2026-01-28T00:23:48.403Z" }, { url = 
"https://files.pythonhosted.org/packages/9c/fe/e4a1b0c989b00cee5ffa0764401767e2d1cf59f45530963b894129fd5dce/cryptography-46.0.4-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:b1de0ebf7587f28f9190b9cb526e901bf448c9e6a99655d2b07fff60e8212a82", size = 4396520, upload-time = "2026-01-28T00:23:50.26Z" }, { url = "https://files.pythonhosted.org/packages/b3/81/ba8fd9657d27076eb40d6a2f941b23429a3c3d2f56f5a921d6b936a27bc9/cryptography-46.0.4-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:9b4d17bc7bd7cdd98e3af40b441feaea4c68225e2eb2341026c84511ad246c0c", size = 4651479, upload-time = "2026-01-28T00:23:51.674Z" }, + { url = "https://files.pythonhosted.org/packages/00/03/0de4ed43c71c31e4fe954edd50b9d28d658fef56555eba7641696370a8e2/cryptography-46.0.4-cp314-cp314t-win32.whl", hash = "sha256:c411f16275b0dea722d76544a61d6421e2cc829ad76eec79280dbdc9ddf50061", size = 3001986, upload-time = "2026-01-28T00:23:53.485Z" }, + { url = "https://files.pythonhosted.org/packages/5c/70/81830b59df7682917d7a10f833c4dab2a5574cd664e86d18139f2b421329/cryptography-46.0.4-cp314-cp314t-win_amd64.whl", hash = "sha256:728fedc529efc1439eb6107b677f7f7558adab4553ef8669f0d02d42d7b959a7", size = 3468288, upload-time = "2026-01-28T00:23:55.09Z" }, + { url = "https://files.pythonhosted.org/packages/56/f7/f648fdbb61d0d45902d3f374217451385edc7e7768d1b03ff1d0e5ffc17b/cryptography-46.0.4-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:a9556ba711f7c23f77b151d5798f3ac44a13455cc68db7697a1096e6d0563cab", size = 7169583, upload-time = "2026-01-28T00:23:56.558Z" }, { url = "https://files.pythonhosted.org/packages/d8/cc/8f3224cbb2a928de7298d6ed4790f5ebc48114e02bdc9559196bfb12435d/cryptography-46.0.4-cp38-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:8bf75b0259e87fa70bddc0b8b4078b76e7fd512fd9afae6c1193bcf440a4dbef", size = 4275419, upload-time = "2026-01-28T00:23:58.364Z" }, { url = 
"https://files.pythonhosted.org/packages/17/43/4a18faa7a872d00e4264855134ba82d23546c850a70ff209e04ee200e76f/cryptography-46.0.4-cp38-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:3c268a3490df22270955966ba236d6bc4a8f9b6e4ffddb78aac535f1a5ea471d", size = 4419058, upload-time = "2026-01-28T00:23:59.867Z" }, { url = "https://files.pythonhosted.org/packages/ee/64/6651969409821d791ba12346a124f55e1b76f66a819254ae840a965d4b9c/cryptography-46.0.4-cp38-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:812815182f6a0c1d49a37893a303b44eaac827d7f0d582cecfc81b6427f22973", size = 4278151, upload-time = "2026-01-28T00:24:01.731Z" }, + { url = "https://files.pythonhosted.org/packages/20/0b/a7fce65ee08c3c02f7a8310cc090a732344066b990ac63a9dfd0a655d321/cryptography-46.0.4-cp38-abi3-manylinux_2_28_ppc64le.whl", hash = "sha256:a90e43e3ef65e6dcf969dfe3bb40cbf5aef0d523dff95bfa24256be172a845f4", size = 4939441, upload-time = "2026-01-28T00:24:03.175Z" }, { url = "https://files.pythonhosted.org/packages/db/a7/20c5701e2cd3e1dfd7a19d2290c522a5f435dd30957d431dcb531d0f1413/cryptography-46.0.4-cp38-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:a05177ff6296644ef2876fce50518dffb5bcdf903c85250974fc8bc85d54c0af", size = 4451617, upload-time = "2026-01-28T00:24:05.403Z" }, { url = "https://files.pythonhosted.org/packages/00/dc/3e16030ea9aa47b63af6524c354933b4fb0e352257c792c4deeb0edae367/cryptography-46.0.4-cp38-abi3-manylinux_2_31_armv7l.whl", hash = "sha256:daa392191f626d50f1b136c9b4cf08af69ca8279d110ea24f5c2700054d2e263", size = 3977774, upload-time = "2026-01-28T00:24:06.851Z" }, { url = "https://files.pythonhosted.org/packages/42/c8/ad93f14118252717b465880368721c963975ac4b941b7ef88f3c56bf2897/cryptography-46.0.4-cp38-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:e07ea39c5b048e085f15923511d8121e4a9dc45cee4e3b970ca4f0d338f23095", size = 4277008, upload-time = "2026-01-28T00:24:08.926Z" }, + { url = 
"https://files.pythonhosted.org/packages/00/cf/89c99698151c00a4631fbfcfcf459d308213ac29e321b0ff44ceeeac82f1/cryptography-46.0.4-cp38-abi3-manylinux_2_34_ppc64le.whl", hash = "sha256:d5a45ddc256f492ce42a4e35879c5e5528c09cd9ad12420828c972951d8e016b", size = 4903339, upload-time = "2026-01-28T00:24:12.009Z" }, { url = "https://files.pythonhosted.org/packages/03/c3/c90a2cb358de4ac9309b26acf49b2a100957e1ff5cc1e98e6c4996576710/cryptography-46.0.4-cp38-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:6bb5157bf6a350e5b28aee23beb2d84ae6f5be390b2f8ee7ea179cda077e1019", size = 4451216, upload-time = "2026-01-28T00:24:13.975Z" }, { url = "https://files.pythonhosted.org/packages/96/2c/8d7f4171388a10208671e181ca43cdc0e596d8259ebacbbcfbd16de593da/cryptography-46.0.4-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:dd5aba870a2c40f87a3af043e0dee7d9eb02d4aff88a797b48f2b43eff8c3ab4", size = 4404299, upload-time = "2026-01-28T00:24:16.169Z" }, { url = "https://files.pythonhosted.org/packages/e9/23/cbb2036e450980f65c6e0a173b73a56ff3bccd8998965dea5cc9ddd424a5/cryptography-46.0.4-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:93d8291da8d71024379ab2cb0b5c57915300155ad42e07f76bea6ad838d7e59b", size = 4664837, upload-time = "2026-01-28T00:24:17.629Z" }, + { url = "https://files.pythonhosted.org/packages/0a/21/f7433d18fe6d5845329cbdc597e30caf983229c7a245bcf54afecc555938/cryptography-46.0.4-cp38-abi3-win32.whl", hash = "sha256:0563655cb3c6d05fb2afe693340bc050c30f9f34e15763361cf08e94749401fc", size = 3009779, upload-time = "2026-01-28T00:24:20.198Z" }, + { url = "https://files.pythonhosted.org/packages/3a/6a/bd2e7caa2facffedf172a45c1a02e551e6d7d4828658c9a245516a598d94/cryptography-46.0.4-cp38-abi3-win_amd64.whl", hash = "sha256:fa0900b9ef9c49728887d1576fd8d9e7e3ea872fa9b25ef9b64888adc434e976", size = 3466633, upload-time = "2026-01-28T00:24:21.851Z" }, + { url = 
"https://files.pythonhosted.org/packages/59/e0/f9c6c53e1f2a1c2507f00f2faba00f01d2f334b35b0fbfe5286715da2184/cryptography-46.0.4-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:766330cce7416c92b5e90c3bb71b1b79521760cdcfc3a6a1a182d4c9fab23d2b", size = 3476316, upload-time = "2026-01-28T00:24:24.144Z" }, { url = "https://files.pythonhosted.org/packages/27/7a/f8d2d13227a9a1a9fe9c7442b057efecffa41f1e3c51d8622f26b9edbe8f/cryptography-46.0.4-pp311-pypy311_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:c236a44acfb610e70f6b3e1c3ca20ff24459659231ef2f8c48e879e2d32b73da", size = 4216693, upload-time = "2026-01-28T00:24:25.758Z" }, { url = "https://files.pythonhosted.org/packages/c5/de/3787054e8f7972658370198753835d9d680f6cd4a39df9f877b57f0dd69c/cryptography-46.0.4-pp311-pypy311_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:8a15fb869670efa8f83cbffbc8753c1abf236883225aed74cd179b720ac9ec80", size = 4382765, upload-time = "2026-01-28T00:24:27.577Z" }, { url = "https://files.pythonhosted.org/packages/8a/5f/60e0afb019973ba6a0b322e86b3d61edf487a4f5597618a430a2a15f2d22/cryptography-46.0.4-pp311-pypy311_pp73-manylinux_2_34_aarch64.whl", hash = "sha256:fdc3daab53b212472f1524d070735b2f0c214239df131903bae1d598016fa822", size = 4216066, upload-time = "2026-01-28T00:24:29.056Z" }, { url = "https://files.pythonhosted.org/packages/81/8e/bf4a0de294f147fee66f879d9bae6f8e8d61515558e3d12785dd90eca0be/cryptography-46.0.4-pp311-pypy311_pp73-manylinux_2_34_x86_64.whl", hash = "sha256:44cc0675b27cadb71bdbb96099cca1fa051cd11d2ade09e5cd3a2edb929ed947", size = 4382025, upload-time = "2026-01-28T00:24:30.681Z" }, + { url = "https://files.pythonhosted.org/packages/79/f4/9ceb90cfd6a3847069b0b0b353fd3075dc69b49defc70182d8af0c4ca390/cryptography-46.0.4-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:be8c01a7d5a55f9a47d1888162b76c8f49d62b234d88f0ff91a9fbebe32ffbc3", size = 3406043, upload-time = "2026-01-28T00:24:32.236Z" }, ] [[package]] @@ -1330,9 +1347,9 @@ wheels = [ name = "iso8601" 
version = "2.1.0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/b9/f3/ef59cee614d5e0accf6fd0cbba025b93b272e626ca89fb70a3e9187c5d15/iso8601-2.1.0.tar.gz", hash = "sha256:6b1d3829ee8921c4301998c909f7829fa9ed3cbdac0d3b16af2d743aed1ba8df" } +sdist = { url = "https://files.pythonhosted.org/packages/b9/f3/ef59cee614d5e0accf6fd0cbba025b93b272e626ca89fb70a3e9187c5d15/iso8601-2.1.0.tar.gz", hash = "sha256:6b1d3829ee8921c4301998c909f7829fa9ed3cbdac0d3b16af2d743aed1ba8df", size = 6522, upload-time = "2023-10-03T00:25:39.317Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/6c/0c/f37b6a241f0759b7653ffa7213889d89ad49a2b76eb2ddf3b57b2738c347/iso8601-2.1.0-py3-none-any.whl", hash = "sha256:aac4145c4dcb66ad8b648a02830f5e2ff6c24af20f4f482689be402db2429242" }, + { url = "https://files.pythonhosted.org/packages/6c/0c/f37b6a241f0759b7653ffa7213889d89ad49a2b76eb2ddf3b57b2738c347/iso8601-2.1.0-py3-none-any.whl", hash = "sha256:aac4145c4dcb66ad8b648a02830f5e2ff6c24af20f4f482689be402db2429242", size = 7545, upload-time = "2023-10-03T00:25:32.304Z" }, ] [[package]] @@ -1865,83 +1882,83 @@ wheels = [ [[package]] name = "orjson" -version = "3.11.6" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/70/a3/4e09c61a5f0c521cba0bb433639610ae037437669f1a4cbc93799e731d78/orjson-3.11.6.tar.gz", hash = "sha256:0a54c72259f35299fd033042367df781c2f66d10252955ca1efb7db309b954cb", size = 6175856, upload-time = "2026-01-29T15:13:07.942Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/30/3c/098ed0e49c565fdf1ccc6a75b190115d1ca74148bf5b6ab036554a550650/orjson-3.11.6-cp310-cp310-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:a613fc37e007143d5b6286dccb1394cd114b07832417006a02b620ddd8279e37", size = 250411, upload-time = "2026-01-29T15:11:17.941Z" }, - { url = 
"https://files.pythonhosted.org/packages/15/7c/cb11a360fd228ceebade03b1e8e9e138dd4b1b3b11602b72dbdad915aded/orjson-3.11.6-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:46ebee78f709d3ba7a65384cfe285bb0763157c6d2f836e7bde2f12d33a867a2", size = 138147, upload-time = "2026-01-29T15:11:19.659Z" }, - { url = "https://files.pythonhosted.org/packages/4e/4b/e57b5c45ffe69fbef7cbd56e9f40e2dc0d5de920caafefcc6981d1a7efc5/orjson-3.11.6-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:a726fa86d2368cd57990f2bd95ef5495a6e613b08fc9585dfe121ec758fb08d1", size = 135110, upload-time = "2026-01-29T15:11:21.231Z" }, - { url = "https://files.pythonhosted.org/packages/b0/6e/4f21c6256f8cee3c0c69926cf7ac821cfc36f218512eedea2e2dc4a490c8/orjson-3.11.6-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:150f12e59d6864197770c78126e1a6e07a3da73d1728731bf3bc1e8b96ffdbe6", size = 140995, upload-time = "2026-01-29T15:11:22.902Z" }, - { url = "https://files.pythonhosted.org/packages/d0/78/92c36205ba2f6094ba1eea60c8e646885072abe64f155196833988c14b74/orjson-3.11.6-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9a2d9746a5b5ce20c0908ada451eb56da4ffa01552a50789a0354d8636a02953", size = 144435, upload-time = "2026-01-29T15:11:24.124Z" }, - { url = "https://files.pythonhosted.org/packages/4d/52/1b518d164005811eb3fea92650e76e7d9deadb0b41e92c483373b1e82863/orjson-3.11.6-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:afd177f5dd91666d31e9019f1b06d2fcdf8a409a1637ddcb5915085dede85680", size = 142734, upload-time = "2026-01-29T15:11:25.708Z" }, - { url = "https://files.pythonhosted.org/packages/4b/11/60ea7885a2b7c1bf60ed8b5982356078a73785bd3bab392041a5bcf8de7c/orjson-3.11.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8d777ec41a327bd3b7de97ba7bce12cc1007815ca398e4e4de9ec56c022c090b", size = 145802, upload-time = "2026-01-29T15:11:26.917Z" }, - { url 
= "https://files.pythonhosted.org/packages/41/7f/15a927e7958fd4f7560fb6dbb9346bee44a168e40168093c46020d866098/orjson-3.11.6-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:f3a135f83185c87c13ff231fcb7dbb2fa4332a376444bd65135b50ff4cc5265c", size = 147504, upload-time = "2026-01-29T15:11:28.07Z" }, - { url = "https://files.pythonhosted.org/packages/66/1f/cabb9132a533f4f913e29294d0a1ca818b1a9a52e990526fe3f7ddd75f1c/orjson-3.11.6-cp310-cp310-musllinux_1_2_armv7l.whl", hash = "sha256:2a8eeed7d4544cf391a142b0dd06029dac588e96cc692d9ab1c3f05b1e57c7f6", size = 421408, upload-time = "2026-01-29T15:11:29.314Z" }, - { url = "https://files.pythonhosted.org/packages/4c/b9/09bda9257a982e300313e4a9fc9b9c3aaff424d07bcf765bf045e4e3ed03/orjson-3.11.6-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:9d576865a21e5cc6695be8fb78afc812079fd361ce6a027a7d41561b61b33a90", size = 155801, upload-time = "2026-01-29T15:11:30.575Z" }, - { url = "https://files.pythonhosted.org/packages/98/19/4e40ea3e5f4c6a8d51f31fd2382351ee7b396fecca915b17cd1af588175b/orjson-3.11.6-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:925e2df51f60aa50f8797830f2adfc05330425803f4105875bb511ced98b7f89", size = 147647, upload-time = "2026-01-29T15:11:31.856Z" }, - { url = "https://files.pythonhosted.org/packages/5a/73/ef4bd7dd15042cf33a402d16b87b9e969e71edb452b63b6e2b05025d1f7d/orjson-3.11.6-cp310-cp310-win32.whl", hash = "sha256:09dded2de64e77ac0b312ad59f35023548fb87393a57447e1bb36a26c181a90f", size = 139770, upload-time = "2026-01-29T15:11:33.031Z" }, - { url = "https://files.pythonhosted.org/packages/b4/ac/daab6e10467f7fffd7081ba587b492505b49313130ff5446a6fe28bf076e/orjson-3.11.6-cp310-cp310-win_amd64.whl", hash = "sha256:3a63b5e7841ca8635214c6be7c0bf0246aa8c5cd4ef0c419b14362d0b2fb13de", size = 136783, upload-time = "2026-01-29T15:11:34.686Z" }, - { url = 
"https://files.pythonhosted.org/packages/f3/fd/d6b0a36854179b93ed77839f107c4089d91cccc9f9ba1b752b6e3bac5f34/orjson-3.11.6-cp311-cp311-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:e259e85a81d76d9665f03d6129e09e4435531870de5961ddcd0bf6e3a7fde7d7", size = 250029, upload-time = "2026-01-29T15:11:35.942Z" }, - { url = "https://files.pythonhosted.org/packages/a3/bb/22902619826641cf3b627c24aab62e2ad6b571bdd1d34733abb0dd57f67a/orjson-3.11.6-cp311-cp311-macosx_15_0_arm64.whl", hash = "sha256:52263949f41b4a4822c6b1353bcc5ee2f7109d53a3b493501d3369d6d0e7937a", size = 134518, upload-time = "2026-01-29T15:11:37.347Z" }, - { url = "https://files.pythonhosted.org/packages/72/90/7a818da4bba1de711a9653c420749c0ac95ef8f8651cbc1dca551f462fe0/orjson-3.11.6-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6439e742fa7834a24698d358a27346bb203bff356ae0402e7f5df8f749c621a8", size = 137917, upload-time = "2026-01-29T15:11:38.511Z" }, - { url = "https://files.pythonhosted.org/packages/59/0f/02846c1cac8e205cb3822dd8aa8f9114acda216f41fd1999ace6b543418d/orjson-3.11.6-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:b81ffd68f084b4e993e3867acb554a049fa7787cc8710bbcc1e26965580d99be", size = 134923, upload-time = "2026-01-29T15:11:39.711Z" }, - { url = "https://files.pythonhosted.org/packages/94/cf/aeaf683001b474bb3c3c757073a4231dfdfe8467fceaefa5bfd40902c99f/orjson-3.11.6-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a5a5468e5e60f7ef6d7f9044b06c8f94a3c56ba528c6e4f7f06ae95164b595ec", size = 140752, upload-time = "2026-01-29T15:11:41.347Z" }, - { url = "https://files.pythonhosted.org/packages/fc/fe/dad52d8315a65f084044a0819d74c4c9daf9ebe0681d30f525b0d29a31f0/orjson-3.11.6-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:72c5005eb45bd2535632d4f3bec7ad392832cfc46b62a3021da3b48a67734b45", size = 144201, upload-time = "2026-01-29T15:11:42.537Z" }, - { url = 
"https://files.pythonhosted.org/packages/36/bc/ab070dd421565b831801077f1e390c4d4af8bfcecafc110336680a33866b/orjson-3.11.6-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0b14dd49f3462b014455a28a4d810d3549bf990567653eb43765cd847df09145", size = 142380, upload-time = "2026-01-29T15:11:44.309Z" }, - { url = "https://files.pythonhosted.org/packages/e6/d8/4b581c725c3a308717f28bf45a9fdac210bca08b67e8430143699413ff06/orjson-3.11.6-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6e0bb2c1ea30ef302f0f89f9bf3e7f9ab5e2af29dc9f80eb87aa99788e4e2d65", size = 145582, upload-time = "2026-01-29T15:11:45.506Z" }, - { url = "https://files.pythonhosted.org/packages/5b/a2/09aab99b39f9a7f175ea8fa29adb9933a3d01e7d5d603cdee7f1c40c8da2/orjson-3.11.6-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:825e0a85d189533c6bff7e2fc417a28f6fcea53d27125c4551979aecd6c9a197", size = 147270, upload-time = "2026-01-29T15:11:46.782Z" }, - { url = "https://files.pythonhosted.org/packages/b8/2f/5ef8eaf7829dc50da3bf497c7775b21ee88437bc8c41f959aa3504ca6631/orjson-3.11.6-cp311-cp311-musllinux_1_2_armv7l.whl", hash = "sha256:b04575417a26530637f6ab4b1f7b4f666eb0433491091da4de38611f97f2fcf3", size = 421222, upload-time = "2026-01-29T15:11:48.106Z" }, - { url = "https://files.pythonhosted.org/packages/3b/b0/dd6b941294c2b5b13da5fdc7e749e58d0c55a5114ab37497155e83050e95/orjson-3.11.6-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:b83eb2e40e8c4da6d6b340ee6b1d6125f5195eb1b0ebb7eac23c6d9d4f92d224", size = 155562, upload-time = "2026-01-29T15:11:49.408Z" }, - { url = "https://files.pythonhosted.org/packages/8e/09/43924331a847476ae2f9a16bd6d3c9dab301265006212ba0d3d7fd58763a/orjson-3.11.6-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:1f42da604ee65a6b87eef858c913ce3e5777872b19321d11e6fc6d21de89b64f", size = 147432, upload-time = "2026-01-29T15:11:50.635Z" }, - { url = 
"https://files.pythonhosted.org/packages/5d/e9/d9865961081816909f6b49d880749dbbd88425afd7c5bbce0549e2290d77/orjson-3.11.6-cp311-cp311-win32.whl", hash = "sha256:5ae45df804f2d344cffb36c43fdf03c82fb6cd247f5faa41e21891b40dfbf733", size = 139623, upload-time = "2026-01-29T15:11:51.82Z" }, - { url = "https://files.pythonhosted.org/packages/b4/f9/6836edb92f76eec1082919101eb1145d2f9c33c8f2c5e6fa399b82a2aaa8/orjson-3.11.6-cp311-cp311-win_amd64.whl", hash = "sha256:f4295948d65ace0a2d8f2c4ccc429668b7eb8af547578ec882e16bf79b0050b2", size = 136647, upload-time = "2026-01-29T15:11:53.454Z" }, - { url = "https://files.pythonhosted.org/packages/b3/0c/4954082eea948c9ae52ee0bcbaa2f99da3216a71bcc314ab129bde22e565/orjson-3.11.6-cp311-cp311-win_arm64.whl", hash = "sha256:314e9c45e0b81b547e3a1cfa3df3e07a815821b3dac9fe8cb75014071d0c16a4", size = 135327, upload-time = "2026-01-29T15:11:56.616Z" }, - { url = "https://files.pythonhosted.org/packages/14/ba/759f2879f41910b7e5e0cdbd9cf82a4f017c527fb0e972e9869ca7fe4c8e/orjson-3.11.6-cp312-cp312-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:6f03f30cd8953f75f2a439070c743c7336d10ee940da918d71c6f3556af3ddcf", size = 249988, upload-time = "2026-01-29T15:11:58.294Z" }, - { url = "https://files.pythonhosted.org/packages/f0/70/54cecb929e6c8b10104fcf580b0cc7dc551aa193e83787dd6f3daba28bb5/orjson-3.11.6-cp312-cp312-macosx_15_0_arm64.whl", hash = "sha256:af44baae65ef386ad971469a8557a0673bb042b0b9fd4397becd9c2dfaa02588", size = 134445, upload-time = "2026-01-29T15:11:59.819Z" }, - { url = "https://files.pythonhosted.org/packages/f2/6f/ec0309154457b9ba1ad05f11faa4441f76037152f75e1ac577db3ce7ca96/orjson-3.11.6-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c310a48542094e4f7dbb6ac076880994986dda8ca9186a58c3cb70a3514d3231", size = 137708, upload-time = "2026-01-29T15:12:01.488Z" }, - { url = 
"https://files.pythonhosted.org/packages/20/52/3c71b80840f8bab9cb26417302707b7716b7d25f863f3a541bcfa232fe6e/orjson-3.11.6-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:d8dfa7a5d387f15ecad94cb6b2d2d5f4aeea64efd8d526bfc03c9812d01e1cc0", size = 134798, upload-time = "2026-01-29T15:12:02.705Z" }, - { url = "https://files.pythonhosted.org/packages/30/51/b490a43b22ff736282360bd02e6bded455cf31dfc3224e01cd39f919bbd2/orjson-3.11.6-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ba8daee3e999411b50f8b50dbb0a3071dd1845f3f9a1a0a6fa6de86d1689d84d", size = 140839, upload-time = "2026-01-29T15:12:03.956Z" }, - { url = "https://files.pythonhosted.org/packages/95/bc/4bcfe4280c1bc63c5291bb96f98298845b6355da2226d3400e17e7b51e53/orjson-3.11.6-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f89d104c974eafd7436d7a5fdbc57f7a1e776789959a2f4f1b2eab5c62a339f4", size = 144080, upload-time = "2026-01-29T15:12:05.151Z" }, - { url = "https://files.pythonhosted.org/packages/01/74/22970f9ead9ab1f1b5f8c227a6c3aa8d71cd2c5acd005868a1d44f2362fa/orjson-3.11.6-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b2e2e2456788ca5ea75616c40da06fc885a7dc0389780e8a41bf7c5389ba257b", size = 142435, upload-time = "2026-01-29T15:12:06.641Z" }, - { url = "https://files.pythonhosted.org/packages/29/34/d564aff85847ab92c82ee43a7a203683566c2fca0723a5f50aebbe759603/orjson-3.11.6-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2a42efebc45afabb1448001e90458c4020d5c64fbac8a8dc4045b777db76cb5a", size = 145631, upload-time = "2026-01-29T15:12:08.351Z" }, - { url = "https://files.pythonhosted.org/packages/e7/ef/016957a3890752c4aa2368326ea69fa53cdc1fdae0a94a542b6410dbdf52/orjson-3.11.6-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:71b7cbef8471324966c3738c90ba38775563ef01b512feb5ad4805682188d1b9", size = 147058, upload-time = "2026-01-29T15:12:10.023Z" }, - { url = 
"https://files.pythonhosted.org/packages/56/cc/9a899c3972085645b3225569f91a30e221f441e5dc8126e6d060b971c252/orjson-3.11.6-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:f8515e5910f454fe9a8e13c2bb9dc4bae4c1836313e967e72eb8a4ad874f0248", size = 421161, upload-time = "2026-01-29T15:12:11.308Z" }, - { url = "https://files.pythonhosted.org/packages/21/a8/767d3fbd6d9b8fdee76974db40619399355fd49bf91a6dd2c4b6909ccf05/orjson-3.11.6-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:300360edf27c8c9bf7047345a94fddf3a8b8922df0ff69d71d854a170cb375cf", size = 155757, upload-time = "2026-01-29T15:12:12.776Z" }, - { url = "https://files.pythonhosted.org/packages/ad/0b/205cd69ac87e2272e13ef3f5f03a3d4657e317e38c1b08aaa2ef97060bbc/orjson-3.11.6-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:caaed4dad39e271adfadc106fab634d173b2bb23d9cf7e67bd645f879175ebfc", size = 147446, upload-time = "2026-01-29T15:12:14.166Z" }, - { url = "https://files.pythonhosted.org/packages/de/c5/dd9f22aa9f27c54c7d05cc32f4580c9ac9b6f13811eeb81d6c4c3f50d6b1/orjson-3.11.6-cp312-cp312-win32.whl", hash = "sha256:955368c11808c89793e847830e1b1007503a5923ddadc108547d3b77df761044", size = 139717, upload-time = "2026-01-29T15:12:15.7Z" }, - { url = "https://files.pythonhosted.org/packages/23/a1/e62fc50d904486970315a1654b8cfb5832eb46abb18cd5405118e7e1fc79/orjson-3.11.6-cp312-cp312-win_amd64.whl", hash = "sha256:2c68de30131481150073d90a5d227a4a421982f42c025ecdfb66157f9579e06f", size = 136711, upload-time = "2026-01-29T15:12:17.055Z" }, - { url = "https://files.pythonhosted.org/packages/04/3d/b4fefad8bdf91e0fe212eb04975aeb36ea92997269d68857efcc7eb1dda3/orjson-3.11.6-cp312-cp312-win_arm64.whl", hash = "sha256:65dfa096f4e3a5e02834b681f539a87fbe85adc82001383c0db907557f666bfc", size = 135212, upload-time = "2026-01-29T15:12:18.3Z" }, - { url = 
"https://files.pythonhosted.org/packages/ae/45/d9c71c8c321277bc1ceebf599bc55ba826ae538b7c61f287e9a7e71bd589/orjson-3.11.6-cp313-cp313-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:e4ae1670caabb598a88d385798692ce2a1b2f078971b3329cfb85253c6097f5b", size = 249828, upload-time = "2026-01-29T15:12:20.14Z" }, - { url = "https://files.pythonhosted.org/packages/ac/7e/4afcf4cfa9c2f93846d70eee9c53c3c0123286edcbeb530b7e9bd2aea1b2/orjson-3.11.6-cp313-cp313-macosx_15_0_arm64.whl", hash = "sha256:2c6b81f47b13dac2caa5d20fbc953c75eb802543abf48403a4703ed3bff225f0", size = 134339, upload-time = "2026-01-29T15:12:22.01Z" }, - { url = "https://files.pythonhosted.org/packages/40/10/6d2b8a064c8d2411d3d0ea6ab43125fae70152aef6bea77bb50fa54d4097/orjson-3.11.6-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:647d6d034e463764e86670644bdcaf8e68b076e6e74783383b01085ae9ab334f", size = 137662, upload-time = "2026-01-29T15:12:23.307Z" }, - { url = "https://files.pythonhosted.org/packages/5a/50/5804ea7d586baf83ee88969eefda97a24f9a5bdba0727f73e16305175b26/orjson-3.11.6-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:8523b9cc4ef174ae52414f7699e95ee657c16aa18b3c3c285d48d7966cce9081", size = 134626, upload-time = "2026-01-29T15:12:25.099Z" }, - { url = "https://files.pythonhosted.org/packages/9e/2e/f0492ed43e376722bb4afd648e06cc1e627fc7ec8ff55f6ee739277813ea/orjson-3.11.6-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:313dfd7184cde50c733fc0d5c8c0e2f09017b573afd11dc36bd7476b30b4cb17", size = 140873, upload-time = "2026-01-29T15:12:26.369Z" }, - { url = "https://files.pythonhosted.org/packages/10/15/6f874857463421794a303a39ac5494786ad46a4ab46d92bda6705d78c5aa/orjson-3.11.6-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:905ee036064ff1e1fd1fb800055ac477cdcb547a78c22c1bc2bbf8d5d1a6fb42", size = 144044, upload-time = "2026-01-29T15:12:28.082Z" }, - { url = 
"https://files.pythonhosted.org/packages/d2/c7/b7223a3a70f1d0cc2d86953825de45f33877ee1b124a91ca1f79aa6e643f/orjson-3.11.6-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ce374cb98411356ba906914441fc993f271a7a666d838d8de0e0900dd4a4bc12", size = 142396, upload-time = "2026-01-29T15:12:30.529Z" }, - { url = "https://files.pythonhosted.org/packages/87/e3/aa1b6d3ad3cd80f10394134f73ae92a1d11fdbe974c34aa199cc18bb5fcf/orjson-3.11.6-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cded072b9f65fcfd188aead45efa5bd528ba552add619b3ad2a81f67400ec450", size = 145600, upload-time = "2026-01-29T15:12:31.848Z" }, - { url = "https://files.pythonhosted.org/packages/f6/cf/e4aac5a46cbd39d7e769ef8650efa851dfce22df1ba97ae2b33efe893b12/orjson-3.11.6-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:7ab85bdbc138e1f73a234db6bb2e4cc1f0fcec8f4bd2bd2430e957a01aadf746", size = 146967, upload-time = "2026-01-29T15:12:33.203Z" }, - { url = "https://files.pythonhosted.org/packages/0b/04/975b86a4bcf6cfeda47aad15956d52fbeda280811206e9967380fa9355c8/orjson-3.11.6-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:351b96b614e3c37a27b8ab048239ebc1e0be76cc17481a430d70a77fb95d3844", size = 421003, upload-time = "2026-01-29T15:12:35.097Z" }, - { url = "https://files.pythonhosted.org/packages/28/d1/0369d0baf40eea5ff2300cebfe209883b2473ab4aa4c4974c8bd5ee42bb2/orjson-3.11.6-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:f9959c85576beae5cdcaaf39510b15105f1ee8b70d5dacd90152617f57be8c83", size = 155695, upload-time = "2026-01-29T15:12:36.589Z" }, - { url = "https://files.pythonhosted.org/packages/ab/1f/d10c6d6ae26ff1d7c3eea6fd048280ef2e796d4fb260c5424fd021f68ecf/orjson-3.11.6-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:75682d62b1b16b61a30716d7a2ec1f4c36195de4a1c61f6665aedd947b93a5d5", size = 147392, upload-time = "2026-01-29T15:12:37.876Z" }, - { url = 
"https://files.pythonhosted.org/packages/8d/43/7479921c174441a0aa5277c313732e20713c0969ac303be9f03d88d3db5d/orjson-3.11.6-cp313-cp313-win32.whl", hash = "sha256:40dc277999c2ef227dcc13072be879b4cfd325502daeb5c35ed768f706f2bf30", size = 139718, upload-time = "2026-01-29T15:12:39.274Z" }, - { url = "https://files.pythonhosted.org/packages/88/bc/9ffe7dfbf8454bc4e75bb8bf3a405ed9e0598df1d3535bb4adcd46be07d0/orjson-3.11.6-cp313-cp313-win_amd64.whl", hash = "sha256:f0f6e9f8ff7905660bc3c8a54cd4a675aa98f7f175cf00a59815e2ff42c0d916", size = 136635, upload-time = "2026-01-29T15:12:40.593Z" }, - { url = "https://files.pythonhosted.org/packages/6f/7e/51fa90b451470447ea5023b20d83331ec741ae28d1e6d8ed547c24e7de14/orjson-3.11.6-cp313-cp313-win_arm64.whl", hash = "sha256:1608999478664de848e5900ce41f25c4ecdfc4beacbc632b6fd55e1a586e5d38", size = 135175, upload-time = "2026-01-29T15:12:41.997Z" }, - { url = "https://files.pythonhosted.org/packages/31/9f/46ca908abaeeec7560638ff20276ab327b980d73b3cc2f5b205b4a1c60b3/orjson-3.11.6-cp314-cp314-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:6026db2692041d2a23fe2545606df591687787825ad5821971ef0974f2c47630", size = 249823, upload-time = "2026-01-29T15:12:43.332Z" }, - { url = "https://files.pythonhosted.org/packages/ff/78/ca478089818d18c9cd04f79c43f74ddd031b63c70fa2a946eb5e85414623/orjson-3.11.6-cp314-cp314-macosx_15_0_arm64.whl", hash = "sha256:132b0ab2e20c73afa85cf142e547511feb3d2f5b7943468984658f3952b467d4", size = 134328, upload-time = "2026-01-29T15:12:45.171Z" }, - { url = "https://files.pythonhosted.org/packages/39/5e/cbb9d830ed4e47f4375ad8eef8e4fff1bf1328437732c3809054fc4e80be/orjson-3.11.6-cp314-cp314-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b376fb05f20a96ec117d47987dd3b39265c635725bda40661b4c5b73b77b5fde", size = 137651, upload-time = "2026-01-29T15:12:46.602Z" }, - { url = 
"https://files.pythonhosted.org/packages/7c/3a/35df6558c5bc3a65ce0961aefee7f8364e59af78749fc796ea255bfa0cf5/orjson-3.11.6-cp314-cp314-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:954dae4e080574672a1dfcf2a840eddef0f27bd89b0e94903dd0824e9c1db060", size = 134596, upload-time = "2026-01-29T15:12:47.95Z" }, - { url = "https://files.pythonhosted.org/packages/cd/8e/3d32dd7b7f26a19cc4512d6ed0ae3429567c71feef720fe699ff43c5bc9e/orjson-3.11.6-cp314-cp314-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fe515bb89d59e1e4b48637a964f480b35c0a2676de24e65e55310f6016cca7ce", size = 140923, upload-time = "2026-01-29T15:12:49.333Z" }, - { url = "https://files.pythonhosted.org/packages/6c/9c/1efbf5c99b3304f25d6f0d493a8d1492ee98693637c10ce65d57be839d7b/orjson-3.11.6-cp314-cp314-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:380f9709c275917af28feb086813923251e11ee10687257cd7f1ea188bcd4485", size = 144068, upload-time = "2026-01-29T15:12:50.927Z" }, - { url = "https://files.pythonhosted.org/packages/82/83/0d19eeb5be797de217303bbb55dde58dba26f996ed905d301d98fd2d4637/orjson-3.11.6-cp314-cp314-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a8173e0d3f6081e7034c51cf984036d02f6bab2a2126de5a759d79f8e5a140e7", size = 142493, upload-time = "2026-01-29T15:12:52.432Z" }, - { url = "https://files.pythonhosted.org/packages/32/a7/573fec3df4dc8fc259b7770dc6c0656f91adce6e19330c78d23f87945d1e/orjson-3.11.6-cp314-cp314-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6dddf9ba706294906c56ef5150a958317b09aa3a8a48df1c52ccf22ec1907eac", size = 145616, upload-time = "2026-01-29T15:12:53.903Z" }, - { url = "https://files.pythonhosted.org/packages/c2/0e/23551b16f21690f7fd5122e3cf40fdca5d77052a434d0071990f97f5fe2f/orjson-3.11.6-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:cbae5c34588dc79938dffb0b6fbe8c531f4dc8a6ad7f39759a9eb5d2da405ef2", size = 146951, upload-time = "2026-01-29T15:12:55.698Z" }, - { url = 
"https://files.pythonhosted.org/packages/b8/63/5e6c8f39805c39123a18e412434ea364349ee0012548d08aa586e2bd6aa9/orjson-3.11.6-cp314-cp314-musllinux_1_2_armv7l.whl", hash = "sha256:f75c318640acbddc419733b57f8a07515e587a939d8f54363654041fd1f4e465", size = 421024, upload-time = "2026-01-29T15:12:57.434Z" }, - { url = "https://files.pythonhosted.org/packages/1d/4d/724975cf0087f6550bd01fd62203418afc0ea33fd099aed318c5bcc52df8/orjson-3.11.6-cp314-cp314-musllinux_1_2_i686.whl", hash = "sha256:e0ab8d13aa2a3e98b4a43487c9205b2c92c38c054b4237777484d503357c8437", size = 155774, upload-time = "2026-01-29T15:12:59.397Z" }, - { url = "https://files.pythonhosted.org/packages/a8/a3/f4c4e3f46b55db29e0a5f20493b924fc791092d9a03ff2068c9fe6c1002f/orjson-3.11.6-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:f884c7fb1020d44612bd7ac0db0babba0e2f78b68d9a650c7959bf99c783773f", size = 147393, upload-time = "2026-01-29T15:13:00.769Z" }, - { url = "https://files.pythonhosted.org/packages/ee/86/6f5529dd27230966171ee126cecb237ed08e9f05f6102bfaf63e5b32277d/orjson-3.11.6-cp314-cp314-win32.whl", hash = "sha256:8d1035d1b25732ec9f971e833a3e299d2b1a330236f75e6fd945ad982c76aaf3", size = 139760, upload-time = "2026-01-29T15:13:02.173Z" }, - { url = "https://files.pythonhosted.org/packages/d3/b5/91ae7037b2894a6b5002fb33f4fbccec98424a928469835c3837fbb22a9b/orjson-3.11.6-cp314-cp314-win_amd64.whl", hash = "sha256:931607a8865d21682bb72de54231655c86df1870502d2962dbfd12c82890d077", size = 136633, upload-time = "2026-01-29T15:13:04.267Z" }, - { url = "https://files.pythonhosted.org/packages/55/74/f473a3ec7a0a7ebc825ca8e3c86763f7d039f379860c81ba12dcdd456547/orjson-3.11.6-cp314-cp314-win_arm64.whl", hash = "sha256:fe71f6b283f4f1832204ab8235ce07adad145052614f77c876fcf0dac97bc06f", size = 135168, upload-time = "2026-01-29T15:13:05.932Z" }, +version = "3.11.7" +source = { registry = "https://pypi.org/simple" } +sdist = { url = 
"https://files.pythonhosted.org/packages/53/45/b268004f745ede84e5798b48ee12b05129d19235d0e15267aa57dcdb400b/orjson-3.11.7.tar.gz", hash = "sha256:9b1a67243945819ce55d24a30b59d6a168e86220452d2c96f4d1f093e71c0c49", size = 6144992, upload-time = "2026-02-02T15:38:49.29Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/de/1a/a373746fa6d0e116dd9e54371a7b54622c44d12296d5d0f3ad5e3ff33490/orjson-3.11.7-cp310-cp310-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:a02c833f38f36546ba65a452127633afce4cf0dd7296b753d3bb54e55e5c0174", size = 229140, upload-time = "2026-02-02T15:37:06.082Z" }, + { url = "https://files.pythonhosted.org/packages/52/a2/fa129e749d500f9b183e8a3446a193818a25f60261e9ce143ad61e975208/orjson-3.11.7-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b63c6e6738d7c3470ad01601e23376aa511e50e1f3931395b9f9c722406d1a67", size = 128670, upload-time = "2026-02-02T15:37:08.002Z" }, + { url = "https://files.pythonhosted.org/packages/08/93/1e82011cd1e0bd051ef9d35bed1aa7fb4ea1f0a055dc2c841b46b43a9ebd/orjson-3.11.7-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:043d3006b7d32c7e233b8cfb1f01c651013ea079e08dcef7189a29abd8befe11", size = 123832, upload-time = "2026-02-02T15:37:09.191Z" }, + { url = "https://files.pythonhosted.org/packages/fe/d8/a26b431ef962c7d55736674dddade876822f3e33223c1f47a36879350d04/orjson-3.11.7-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:57036b27ac8a25d81112eb0cc9835cd4833c5b16e1467816adc0015f59e870dc", size = 129171, upload-time = "2026-02-02T15:37:11.112Z" }, + { url = "https://files.pythonhosted.org/packages/a7/19/f47819b84a580f490da260c3ee9ade214cf4cf78ac9ce8c1c758f80fdfc9/orjson-3.11.7-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:733ae23ada68b804b222c44affed76b39e30806d38660bf1eb200520d259cc16", size = 141967, upload-time = "2026-02-02T15:37:12.282Z" }, + { url = 
"https://files.pythonhosted.org/packages/5b/cd/37ece39a0777ba077fdcdbe4cccae3be8ed00290c14bf8afdc548befc260/orjson-3.11.7-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5fdfad2093bdd08245f2e204d977facd5f871c88c4a71230d5bcbd0e43bf6222", size = 130991, upload-time = "2026-02-02T15:37:13.465Z" }, + { url = "https://files.pythonhosted.org/packages/8f/ed/f2b5d66aa9b6b5c02ff5f120efc7b38c7c4962b21e6be0f00fd99a5c348e/orjson-3.11.7-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cededd6738e1c153530793998e31c05086582b08315db48ab66649768f326baa", size = 133674, upload-time = "2026-02-02T15:37:14.694Z" }, + { url = "https://files.pythonhosted.org/packages/c4/6e/baa83e68d1aa09fa8c3e5b2c087d01d0a0bd45256de719ed7bc22c07052d/orjson-3.11.7-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:14f440c7268c8f8633d1b3d443a434bd70cb15686117ea6beff8fdc8f5917a1e", size = 138722, upload-time = "2026-02-02T15:37:16.501Z" }, + { url = "https://files.pythonhosted.org/packages/0c/47/7f8ef4963b772cd56999b535e553f7eb5cd27e9dd6c049baee6f18bfa05d/orjson-3.11.7-cp310-cp310-musllinux_1_2_armv7l.whl", hash = "sha256:3a2479753bbb95b0ebcf7969f562cdb9668e6d12416a35b0dda79febf89cdea2", size = 409056, upload-time = "2026-02-02T15:37:17.895Z" }, + { url = "https://files.pythonhosted.org/packages/38/eb/2df104dd2244b3618f25325a656f85cc3277f74bbd91224752410a78f3c7/orjson-3.11.7-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:71924496986275a737f38e3f22b4e0878882b3f7a310d2ff4dc96e812789120c", size = 144196, upload-time = "2026-02-02T15:37:19.349Z" }, + { url = "https://files.pythonhosted.org/packages/b6/2a/ee41de0aa3a6686598661eae2b4ebdff1340c65bfb17fcff8b87138aab21/orjson-3.11.7-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:b4a9eefdc70bf8bf9857f0290f973dec534ac84c35cd6a7f4083be43e7170a8f", size = 134979, upload-time = "2026-02-02T15:37:20.906Z" }, + { url = 
"https://files.pythonhosted.org/packages/4c/fa/92fc5d3d402b87a8b28277a9ed35386218a6a5287c7fe5ee9b9f02c53fb2/orjson-3.11.7-cp310-cp310-win32.whl", hash = "sha256:ae9e0b37a834cef7ce8f99de6498f8fad4a2c0bf6bfc3d02abd8ed56aa15b2de", size = 127968, upload-time = "2026-02-02T15:37:23.178Z" }, + { url = "https://files.pythonhosted.org/packages/07/29/a576bf36d73d60df06904d3844a9df08e25d59eba64363aaf8ec2f9bff41/orjson-3.11.7-cp310-cp310-win_amd64.whl", hash = "sha256:d772afdb22555f0c58cfc741bdae44180122b3616faa1ecadb595cd526e4c993", size = 125128, upload-time = "2026-02-02T15:37:24.329Z" }, + { url = "https://files.pythonhosted.org/packages/37/02/da6cb01fc6087048d7f61522c327edf4250f1683a58a839fdcc435746dd5/orjson-3.11.7-cp311-cp311-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:9487abc2c2086e7c8eb9a211d2ce8855bae0e92586279d0d27b341d5ad76c85c", size = 228664, upload-time = "2026-02-02T15:37:25.542Z" }, + { url = "https://files.pythonhosted.org/packages/c1/c2/5885e7a5881dba9a9af51bc564e8967225a642b3e03d089289a35054e749/orjson-3.11.7-cp311-cp311-macosx_15_0_arm64.whl", hash = "sha256:79cacb0b52f6004caf92405a7e1f11e6e2de8bdf9019e4f76b44ba045125cd6b", size = 125344, upload-time = "2026-02-02T15:37:26.92Z" }, + { url = "https://files.pythonhosted.org/packages/a4/1d/4e7688de0a92d1caf600dfd5fb70b4c5bfff51dfa61ac555072ef2d0d32a/orjson-3.11.7-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c2e85fe4698b6a56d5e2ebf7ae87544d668eb6bde1ad1226c13f44663f20ec9e", size = 128404, upload-time = "2026-02-02T15:37:28.108Z" }, + { url = "https://files.pythonhosted.org/packages/2f/b2/ec04b74ae03a125db7bd69cffd014b227b7f341e3261bf75b5eb88a1aa92/orjson-3.11.7-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:b8d14b71c0b12963fe8a62aac87119f1afdf4cb88a400f61ca5ae581449efcb5", size = 123677, upload-time = "2026-02-02T15:37:30.287Z" }, + { url = 
"https://files.pythonhosted.org/packages/4c/69/f95bdf960605f08f827f6e3291fe243d8aa9c5c9ff017a8d7232209184c3/orjson-3.11.7-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:91c81ef070c8f3220054115e1ef468b1c9ce8497b4e526cb9f68ab4dc0a7ac62", size = 128950, upload-time = "2026-02-02T15:37:31.595Z" }, + { url = "https://files.pythonhosted.org/packages/a4/1b/de59c57bae1d148ef298852abd31909ac3089cff370dfd4cd84cc99cbc42/orjson-3.11.7-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:411ebaf34d735e25e358a6d9e7978954a9c9d58cfb47bc6683cdc3964cd2f910", size = 141756, upload-time = "2026-02-02T15:37:32.985Z" }, + { url = "https://files.pythonhosted.org/packages/ee/9e/9decc59f4499f695f65c650f6cfa6cd4c37a3fbe8fa235a0a3614cb54386/orjson-3.11.7-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a16bcd08ab0bcdfc7e8801d9c4a9cc17e58418e4d48ddc6ded4e9e4b1a94062b", size = 130812, upload-time = "2026-02-02T15:37:34.204Z" }, + { url = "https://files.pythonhosted.org/packages/28/e6/59f932bcabd1eac44e334fe8e3281a92eacfcb450586e1f4bde0423728d8/orjson-3.11.7-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9c0b51672e466fd7e56230ffbae7f1639e18d0ce023351fb75da21b71bc2c960", size = 133444, upload-time = "2026-02-02T15:37:35.446Z" }, + { url = "https://files.pythonhosted.org/packages/f1/36/b0f05c0eaa7ca30bc965e37e6a2956b0d67adb87a9872942d3568da846ae/orjson-3.11.7-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:136dcd6a2e796dfd9ffca9fc027d778567b0b7c9968d092842d3c323cef88aa8", size = 138609, upload-time = "2026-02-02T15:37:36.657Z" }, + { url = "https://files.pythonhosted.org/packages/b8/03/58ec7d302b8d86944c60c7b4b82975d5161fcce4c9bc8c6cb1d6741b6115/orjson-3.11.7-cp311-cp311-musllinux_1_2_armv7l.whl", hash = "sha256:7ba61079379b0ae29e117db13bda5f28d939766e410d321ec1624afc6a0b0504", size = 408918, upload-time = "2026-02-02T15:37:38.076Z" }, + { url = 
"https://files.pythonhosted.org/packages/06/3a/868d65ef9a8b99be723bd510de491349618abd9f62c826cf206d962db295/orjson-3.11.7-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:0527a4510c300e3b406591b0ba69b5dc50031895b0a93743526a3fc45f59d26e", size = 143998, upload-time = "2026-02-02T15:37:39.706Z" }, + { url = "https://files.pythonhosted.org/packages/5b/c7/1e18e1c83afe3349f4f6dc9e14910f0ae5f82eac756d1412ea4018938535/orjson-3.11.7-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:a709e881723c9b18acddcfb8ba357322491ad553e277cf467e1e7e20e2d90561", size = 134802, upload-time = "2026-02-02T15:37:41.002Z" }, + { url = "https://files.pythonhosted.org/packages/d4/0b/ccb7ee1a65b37e8eeb8b267dc953561d72370e85185e459616d4345bab34/orjson-3.11.7-cp311-cp311-win32.whl", hash = "sha256:c43b8b5bab288b6b90dac410cca7e986a4fa747a2e8f94615aea407da706980d", size = 127828, upload-time = "2026-02-02T15:37:42.241Z" }, + { url = "https://files.pythonhosted.org/packages/af/9e/55c776dffda3f381e0f07d010a4f5f3902bf48eaba1bb7684d301acd4924/orjson-3.11.7-cp311-cp311-win_amd64.whl", hash = "sha256:6543001328aa857187f905308a028935864aefe9968af3848401b6fe80dbb471", size = 124941, upload-time = "2026-02-02T15:37:43.444Z" }, + { url = "https://files.pythonhosted.org/packages/aa/8e/424a620fa7d263b880162505fb107ef5e0afaa765b5b06a88312ac291560/orjson-3.11.7-cp311-cp311-win_arm64.whl", hash = "sha256:1ee5cc7160a821dfe14f130bc8e63e7611051f964b463d9e2a3a573204446a4d", size = 126245, upload-time = "2026-02-02T15:37:45.18Z" }, + { url = "https://files.pythonhosted.org/packages/80/bf/76f4f1665f6983385938f0e2a5d7efa12a58171b8456c252f3bae8a4cf75/orjson-3.11.7-cp312-cp312-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:bd03ea7606833655048dab1a00734a2875e3e86c276e1d772b2a02556f0d895f", size = 228545, upload-time = "2026-02-02T15:37:46.376Z" }, + { url = 
"https://files.pythonhosted.org/packages/79/53/6c72c002cb13b5a978a068add59b25a8bdf2800ac1c9c8ecdb26d6d97064/orjson-3.11.7-cp312-cp312-macosx_15_0_arm64.whl", hash = "sha256:89e440ebc74ce8ab5c7bc4ce6757b4a6b1041becb127df818f6997b5c71aa60b", size = 125224, upload-time = "2026-02-02T15:37:47.697Z" }, + { url = "https://files.pythonhosted.org/packages/2c/83/10e48852865e5dd151bdfe652c06f7da484578ed02c5fca938e3632cb0b8/orjson-3.11.7-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5ede977b5fe5ac91b1dffc0a517ca4542d2ec8a6a4ff7b2652d94f640796342a", size = 128154, upload-time = "2026-02-02T15:37:48.954Z" }, + { url = "https://files.pythonhosted.org/packages/6e/52/a66e22a2b9abaa374b4a081d410edab6d1e30024707b87eab7c734afe28d/orjson-3.11.7-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:b7b1dae39230a393df353827c855a5f176271c23434cfd2db74e0e424e693e10", size = 123548, upload-time = "2026-02-02T15:37:50.187Z" }, + { url = "https://files.pythonhosted.org/packages/de/38/605d371417021359f4910c496f764c48ceb8997605f8c25bf1dfe58c0ebe/orjson-3.11.7-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ed46f17096e28fb28d2975834836a639af7278aa87c84f68ab08fbe5b8bd75fa", size = 129000, upload-time = "2026-02-02T15:37:51.426Z" }, + { url = "https://files.pythonhosted.org/packages/44/98/af32e842b0ffd2335c89714d48ca4e3917b42f5d6ee5537832e069a4b3ac/orjson-3.11.7-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3726be79e36e526e3d9c1aceaadbfb4a04ee80a72ab47b3f3c17fefb9812e7b8", size = 141686, upload-time = "2026-02-02T15:37:52.607Z" }, + { url = "https://files.pythonhosted.org/packages/96/0b/fc793858dfa54be6feee940c1463370ece34b3c39c1ca0aa3845f5ba9892/orjson-3.11.7-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0724e265bc548af1dedebd9cb3d24b4e1c1e685a343be43e87ba922a5c5fff2f", size = 130812, upload-time = "2026-02-02T15:37:53.944Z" }, + { url = 
"https://files.pythonhosted.org/packages/dc/91/98a52415059db3f374757d0b7f0f16e3b5cd5976c90d1c2b56acaea039e6/orjson-3.11.7-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e7745312efa9e11c17fbd3cb3097262d079da26930ae9ae7ba28fb738367cbad", size = 133440, upload-time = "2026-02-02T15:37:55.615Z" }, + { url = "https://files.pythonhosted.org/packages/dc/b6/cb540117bda61791f46381f8c26c8f93e802892830a6055748d3bb1925ab/orjson-3.11.7-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:f904c24bdeabd4298f7a977ef14ca2a022ca921ed670b92ecd16ab6f3d01f867", size = 138386, upload-time = "2026-02-02T15:37:56.814Z" }, + { url = "https://files.pythonhosted.org/packages/63/1a/50a3201c334a7f17c231eee5f841342190723794e3b06293f26e7cf87d31/orjson-3.11.7-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:b9fc4d0f81f394689e0814617aadc4f2ea0e8025f38c226cbf22d3b5ddbf025d", size = 408853, upload-time = "2026-02-02T15:37:58.291Z" }, + { url = "https://files.pythonhosted.org/packages/87/cd/8de1c67d0be44fdc22701e5989c0d015a2adf391498ad42c4dc589cd3013/orjson-3.11.7-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:849e38203e5be40b776ed2718e587faf204d184fc9a008ae441f9442320c0cab", size = 144130, upload-time = "2026-02-02T15:38:00.163Z" }, + { url = "https://files.pythonhosted.org/packages/0f/fe/d605d700c35dd55f51710d159fc54516a280923cd1b7e47508982fbb387d/orjson-3.11.7-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:4682d1db3bcebd2b64757e0ddf9e87ae5f00d29d16c5cdf3a62f561d08cc3dd2", size = 134818, upload-time = "2026-02-02T15:38:01.507Z" }, + { url = "https://files.pythonhosted.org/packages/e4/e4/15ecc67edb3ddb3e2f46ae04475f2d294e8b60c1825fbe28a428b93b3fbd/orjson-3.11.7-cp312-cp312-win32.whl", hash = "sha256:f4f7c956b5215d949a1f65334cf9d7612dde38f20a95f2315deef167def91a6f", size = 127923, upload-time = "2026-02-02T15:38:02.75Z" }, + { url = 
"https://files.pythonhosted.org/packages/34/70/2e0855361f76198a3965273048c8e50a9695d88cd75811a5b46444895845/orjson-3.11.7-cp312-cp312-win_amd64.whl", hash = "sha256:bf742e149121dc5648ba0a08ea0871e87b660467ef168a3a5e53bc1fbd64bb74", size = 125007, upload-time = "2026-02-02T15:38:04.032Z" }, + { url = "https://files.pythonhosted.org/packages/68/40/c2051bd19fc467610fed469dc29e43ac65891571138f476834ca192bc290/orjson-3.11.7-cp312-cp312-win_arm64.whl", hash = "sha256:26c3b9132f783b7d7903bf1efb095fed8d4a3a85ec0d334ee8beff3d7a4749d5", size = 126089, upload-time = "2026-02-02T15:38:05.297Z" }, + { url = "https://files.pythonhosted.org/packages/89/25/6e0e52cac5aab51d7b6dcd257e855e1dec1c2060f6b28566c509b4665f62/orjson-3.11.7-cp313-cp313-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:1d98b30cc1313d52d4af17d9c3d307b08389752ec5f2e5febdfada70b0f8c733", size = 228390, upload-time = "2026-02-02T15:38:06.8Z" }, + { url = "https://files.pythonhosted.org/packages/a5/29/a77f48d2fc8a05bbc529e5ff481fb43d914f9e383ea2469d4f3d51df3d00/orjson-3.11.7-cp313-cp313-macosx_15_0_arm64.whl", hash = "sha256:d897e81f8d0cbd2abb82226d1860ad2e1ab3ff16d7b08c96ca00df9d45409ef4", size = 125189, upload-time = "2026-02-02T15:38:08.181Z" }, + { url = "https://files.pythonhosted.org/packages/89/25/0a16e0729a0e6a1504f9d1a13cdd365f030068aab64cec6958396b9969d7/orjson-3.11.7-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:814be4b49b228cfc0b3c565acf642dd7d13538f966e3ccde61f4f55be3e20785", size = 128106, upload-time = "2026-02-02T15:38:09.41Z" }, + { url = "https://files.pythonhosted.org/packages/66/da/a2e505469d60666a05ab373f1a6322eb671cb2ba3a0ccfc7d4bc97196787/orjson-3.11.7-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:d06e5c5fed5caedd2e540d62e5b1c25e8c82431b9e577c33537e5fa4aa909539", size = 123363, upload-time = "2026-02-02T15:38:10.73Z" }, + { url = 
"https://files.pythonhosted.org/packages/23/bf/ed73f88396ea35c71b38961734ea4a4746f7ca0768bf28fd551d37e48dd0/orjson-3.11.7-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:31c80ce534ac4ea3739c5ee751270646cbc46e45aea7576a38ffec040b4029a1", size = 129007, upload-time = "2026-02-02T15:38:12.138Z" }, + { url = "https://files.pythonhosted.org/packages/73/3c/b05d80716f0225fc9008fbf8ab22841dcc268a626aa550561743714ce3bf/orjson-3.11.7-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f50979824bde13d32b4320eedd513431c921102796d86be3eee0b58e58a3ecd1", size = 141667, upload-time = "2026-02-02T15:38:13.398Z" }, + { url = "https://files.pythonhosted.org/packages/61/e8/0be9b0addd9bf86abfc938e97441dcd0375d494594b1c8ad10fe57479617/orjson-3.11.7-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9e54f3808e2b6b945078c41aa8d9b5834b28c50843846e97807e5adb75fa9705", size = 130832, upload-time = "2026-02-02T15:38:14.698Z" }, + { url = "https://files.pythonhosted.org/packages/c9/ec/c68e3b9021a31d9ec15a94931db1410136af862955854ed5dd7e7e4f5bff/orjson-3.11.7-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a12b80df61aab7b98b490fe9e4879925ba666fccdfcd175252ce4d9035865ace", size = 133373, upload-time = "2026-02-02T15:38:16.109Z" }, + { url = "https://files.pythonhosted.org/packages/d2/45/f3466739aaafa570cc8e77c6dbb853c48bf56e3b43738020e2661e08b0ac/orjson-3.11.7-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:996b65230271f1a97026fd0e6a753f51fbc0c335d2ad0c6201f711b0da32693b", size = 138307, upload-time = "2026-02-02T15:38:17.453Z" }, + { url = "https://files.pythonhosted.org/packages/e1/84/9f7f02288da1ffb31405c1be07657afd1eecbcb4b64ee2817b6fe0f785fa/orjson-3.11.7-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:ab49d4b2a6a1d415ddb9f37a21e02e0d5dbfe10b7870b21bf779fc21e9156157", size = 408695, upload-time = "2026-02-02T15:38:18.831Z" }, + { url = 
"https://files.pythonhosted.org/packages/18/07/9dd2f0c0104f1a0295ffbe912bc8d63307a539b900dd9e2c48ef7810d971/orjson-3.11.7-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:390a1dce0c055ddf8adb6aa94a73b45a4a7d7177b5c584b8d1c1947f2ba60fb3", size = 144099, upload-time = "2026-02-02T15:38:20.28Z" }, + { url = "https://files.pythonhosted.org/packages/a5/66/857a8e4a3292e1f7b1b202883bcdeb43a91566cf59a93f97c53b44bd6801/orjson-3.11.7-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:1eb80451a9c351a71dfaf5b7ccc13ad065405217726b59fdbeadbcc544f9d223", size = 134806, upload-time = "2026-02-02T15:38:22.186Z" }, + { url = "https://files.pythonhosted.org/packages/0a/5b/6ebcf3defc1aab3a338ca777214966851e92efb1f30dc7fc8285216e6d1b/orjson-3.11.7-cp313-cp313-win32.whl", hash = "sha256:7477aa6a6ec6139c5cb1cc7b214643592169a5494d200397c7fc95d740d5fcf3", size = 127914, upload-time = "2026-02-02T15:38:23.511Z" }, + { url = "https://files.pythonhosted.org/packages/00/04/c6f72daca5092e3117840a1b1e88dfc809cc1470cf0734890d0366b684a1/orjson-3.11.7-cp313-cp313-win_amd64.whl", hash = "sha256:b9f95dcdea9d4f805daa9ddf02617a89e484c6985fa03055459f90e87d7a0757", size = 124986, upload-time = "2026-02-02T15:38:24.836Z" }, + { url = "https://files.pythonhosted.org/packages/03/ba/077a0f6f1085d6b806937246860fafbd5b17f3919c70ee3f3d8d9c713f38/orjson-3.11.7-cp313-cp313-win_arm64.whl", hash = "sha256:800988273a014a0541483dc81021247d7eacb0c845a9d1a34a422bc718f41539", size = 126045, upload-time = "2026-02-02T15:38:26.216Z" }, + { url = "https://files.pythonhosted.org/packages/e9/1e/745565dca749813db9a093c5ebc4bac1a9475c64d54b95654336ac3ed961/orjson-3.11.7-cp314-cp314-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:de0a37f21d0d364954ad5de1970491d7fbd0fb1ef7417d4d56a36dc01ba0c0a0", size = 228391, upload-time = "2026-02-02T15:38:27.757Z" }, + { url = 
"https://files.pythonhosted.org/packages/46/19/e40f6225da4d3aa0c8dc6e5219c5e87c2063a560fe0d72a88deb59776794/orjson-3.11.7-cp314-cp314-macosx_15_0_arm64.whl", hash = "sha256:c2428d358d85e8da9d37cba18b8c4047c55222007a84f97156a5b22028dfbfc0", size = 125188, upload-time = "2026-02-02T15:38:29.241Z" }, + { url = "https://files.pythonhosted.org/packages/9d/7e/c4de2babef2c0817fd1f048fd176aa48c37bec8aef53d2fa932983032cce/orjson-3.11.7-cp314-cp314-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3c4bc6c6ac52cdaa267552544c73e486fecbd710b7ac09bc024d5a78555a22f6", size = 128097, upload-time = "2026-02-02T15:38:30.618Z" }, + { url = "https://files.pythonhosted.org/packages/eb/74/233d360632bafd2197f217eee7fb9c9d0229eac0c18128aee5b35b0014fe/orjson-3.11.7-cp314-cp314-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:bd0d68edd7dfca1b2eca9361a44ac9f24b078de3481003159929a0573f21a6bf", size = 123364, upload-time = "2026-02-02T15:38:32.363Z" }, + { url = "https://files.pythonhosted.org/packages/79/51/af79504981dd31efe20a9e360eb49c15f06df2b40e7f25a0a52d9ae888e8/orjson-3.11.7-cp314-cp314-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:623ad1b9548ef63886319c16fa317848e465a21513b31a6ad7b57443c3e0dcf5", size = 129076, upload-time = "2026-02-02T15:38:33.68Z" }, + { url = "https://files.pythonhosted.org/packages/67/e2/da898eb68b72304f8de05ca6715870d09d603ee98d30a27e8a9629abc64b/orjson-3.11.7-cp314-cp314-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6e776b998ac37c0396093d10290e60283f59cfe0fc3fccbd0ccc4bd04dd19892", size = 141705, upload-time = "2026-02-02T15:38:34.989Z" }, + { url = "https://files.pythonhosted.org/packages/c5/89/15364d92acb3d903b029e28d834edb8780c2b97404cbf7929aa6b9abdb24/orjson-3.11.7-cp314-cp314-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:652c6c3af76716f4a9c290371ba2e390ede06f6603edb277b481daf37f6f464e", size = 130855, upload-time = "2026-02-02T15:38:36.379Z" }, + { url = 
"https://files.pythonhosted.org/packages/c2/8b/ecdad52d0b38d4b8f514be603e69ccd5eacf4e7241f972e37e79792212ec/orjson-3.11.7-cp314-cp314-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a56df3239294ea5964adf074c54bcc4f0ccd21636049a2cf3ca9cf03b5d03cf1", size = 133386, upload-time = "2026-02-02T15:38:37.704Z" }, + { url = "https://files.pythonhosted.org/packages/b9/0e/45e1dcf10e17d0924b7c9162f87ec7b4ca79e28a0548acf6a71788d3e108/orjson-3.11.7-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:bda117c4148e81f746655d5a3239ae9bd00cb7bc3ca178b5fc5a5997e9744183", size = 138295, upload-time = "2026-02-02T15:38:39.096Z" }, + { url = "https://files.pythonhosted.org/packages/63/d7/4d2e8b03561257af0450f2845b91fbd111d7e526ccdf737267108075e0ba/orjson-3.11.7-cp314-cp314-musllinux_1_2_armv7l.whl", hash = "sha256:23d6c20517a97a9daf1d48b580fcdc6f0516c6f4b5038823426033690b4d2650", size = 408720, upload-time = "2026-02-02T15:38:40.634Z" }, + { url = "https://files.pythonhosted.org/packages/78/cf/d45343518282108b29c12a65892445fc51f9319dc3c552ceb51bb5905ed2/orjson-3.11.7-cp314-cp314-musllinux_1_2_i686.whl", hash = "sha256:8ff206156006da5b847c9304b6308a01e8cdbc8cce824e2779a5ba71c3def141", size = 144152, upload-time = "2026-02-02T15:38:42.262Z" }, + { url = "https://files.pythonhosted.org/packages/a9/3a/d6001f51a7275aacd342e77b735c71fa04125a3f93c36fee4526bc8c654e/orjson-3.11.7-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:962d046ee1765f74a1da723f4b33e3b228fe3a48bd307acce5021dfefe0e29b2", size = 134814, upload-time = "2026-02-02T15:38:43.627Z" }, + { url = "https://files.pythonhosted.org/packages/1d/d3/f19b47ce16820cc2c480f7f1723e17f6d411b3a295c60c8ad3aa9ff1c96a/orjson-3.11.7-cp314-cp314-win32.whl", hash = "sha256:89e13dd3f89f1c38a9c9eba5fbf7cdc2d1feca82f5f290864b4b7a6aac704576", size = 127997, upload-time = "2026-02-02T15:38:45.06Z" }, + { url = 
"https://files.pythonhosted.org/packages/12/df/172771902943af54bf661a8d102bdf2e7f932127968080632bda6054b62c/orjson-3.11.7-cp314-cp314-win_amd64.whl", hash = "sha256:845c3e0d8ded9c9271cd79596b9b552448b885b97110f628fb687aee2eed11c1", size = 124985, upload-time = "2026-02-02T15:38:46.388Z" }, + { url = "https://files.pythonhosted.org/packages/6f/1c/f2a8d8a1b17514660a614ce5f7aac74b934e69f5abc2700cc7ced882a009/orjson-3.11.7-cp314-cp314-win_arm64.whl", hash = "sha256:4a2e9c5be347b937a2e0203866f12bba36082e89b402ddb9e927d5822e43088d", size = 126038, upload-time = "2026-02-02T15:38:47.703Z" }, ] [[package]] @@ -2737,15 +2754,15 @@ wheels = [ [[package]] name = "rich" -version = "14.3.1" +version = "14.3.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "markdown-it-py" }, { name = "pygments" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/a1/84/4831f881aa6ff3c976f6d6809b58cdfa350593ffc0dc3c58f5f6586780fb/rich-14.3.1.tar.gz", hash = "sha256:b8c5f568a3a749f9290ec6bddedf835cec33696bfc1e48bcfecb276c7386e4b8", size = 230125, upload-time = "2026-01-24T21:40:44.847Z" } +sdist = { url = "https://files.pythonhosted.org/packages/74/99/a4cab2acbb884f80e558b0771e97e21e939c5dfb460f488d19df485e8298/rich-14.3.2.tar.gz", hash = "sha256:e712f11c1a562a11843306f5ed999475f09ac31ffb64281f73ab29ffdda8b3b8", size = 230143, upload-time = "2026-02-01T16:20:47.908Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/87/2a/a1810c8627b9ec8c57ec5ec325d306701ae7be50235e8fd81266e002a3cc/rich-14.3.1-py3-none-any.whl", hash = "sha256:da750b1aebbff0b372557426fb3f35ba56de8ef954b3190315eb64076d6fb54e", size = 309952, upload-time = "2026-01-24T21:40:42.969Z" }, + { url = "https://files.pythonhosted.org/packages/ef/45/615f5babd880b4bd7d405cc0dc348234c5ffb6ed1ea33e152ede08b2072d/rich-14.3.2-py3-none-any.whl", hash = "sha256:08e67c3e90884651da3239ea668222d19bea7b589149d8014a21c633420dbb69", size = 309963, upload-time = "2026-02-01T16:20:46.078Z" }, ] 
[[package]] @@ -3222,6 +3239,8 @@ docs = [ { name = "sphinx-immaterial" }, ] test = [ + { name = "aiomysql" }, + { name = "cryptography" }, { name = "pytest" }, { name = "pytest-asyncio" }, { name = "pytest-codspeed" }, @@ -3282,6 +3301,8 @@ docs = [ { name = "sphinx-immaterial" }, ] test = [ + { name = "aiomysql", specifier = ">=0.3.2" }, + { name = "cryptography", specifier = ">=46.0.3" }, { name = "pytest" }, { name = "pytest-asyncio" }, { name = "pytest-codspeed" },