diff --git a/tests/test_queryset.py b/tests/test_queryset.py index 7b03d4bdc..49da96e85 100644 --- a/tests/test_queryset.py +++ b/tests/test_queryset.py @@ -21,6 +21,7 @@ ParamsError, ) from tortoise.expressions import F, RawSQL, Subquery +from tortoise.functions import Length # TODO: Test the many exceptions in QuerySet # TODO: .filter(intnum_null=None) does not work as expected @@ -689,3 +690,83 @@ def test_does_not_exist(self): class TestMultiple(TestNotExist): exp_cls = MultipleObjectsReturned + + +class TestQueryReuse(test.TestCase): + async def test_annotations(self): + a = await Tournament.create(name="A") + + base_query = Tournament.annotate(id_plus_one=F("id") + 1) + query1 = base_query.annotate(id_plus_two=F("id") + 2) + query2 = base_query.annotate(id_plus_three=F("id") + 3) + res = await query1.first() + self.assertEqual(res.id_plus_one, a.id + 1) + self.assertEqual(res.id_plus_two, a.id + 2) + with self.assertRaises(AttributeError): + getattr(res, "id_plus_three") + + res = await query2.first() + self.assertEqual(res.id_plus_one, a.id + 1) + self.assertEqual(res.id_plus_three, a.id + 3) + with self.assertRaises(AttributeError): + getattr(res, "id_plus_two") + + res = await query1.first() + with self.assertRaises(AttributeError): + getattr(res, "id_plus_three") + + async def test_filters(self): + a = await Tournament.create(name="A") + b = await Tournament.create(name="B") + await Tournament.create(name="C") + + base_query = Tournament.exclude(name="C") + tournaments = await base_query + self.assertSetEqual(set(tournaments), {a, b}) + + tournaments = await base_query.exclude(name="A") + self.assertSetEqual(set(tournaments), {b}) + + tournaments = await base_query.exclude(name="B") + self.assertSetEqual(set(tournaments), {a}) + + async def test_joins(self): + tournament_a = await Tournament.create(name="A") + tournament_b = await Tournament.create(name="B") + tournament_c = await Tournament.create(name="C") + event_a = await Event.create(name="A", tournament=tournament_a) + event_b = await Event.create(name="B", tournament=tournament_b) + await Event.create(name="C", tournament=tournament_c) + + base_query = Event.exclude(tournament__name="C") + events = await base_query + self.assertSetEqual(set(events), {event_a, event_b}) + + events = await base_query.exclude(name="A") + self.assertSetEqual(set(events), {event_b}) + + events = await base_query.exclude(name="B") + self.assertSetEqual(set(events), {event_a}) + + async def test_order_by(self): + a = await Tournament.create(name="A") + b = await Tournament.create(name="B") + + base_query = Tournament.all().order_by("name") + tournaments = await base_query + self.assertEqual(tournaments, [a, b]) + + tournaments = await base_query.order_by("-name") + self.assertEqual(tournaments, [b, a]) + + @test.requireCapability(dialect=NotEQ("mssql")) + async def test_values_with_annotations(self): + await Tournament.create(name="Championship") + await Tournament.create(name="Super Bowl") + + base_query = Tournament.annotate(name_length=Length("name")) + tournaments = await base_query.values_list("name") + self.assertListSortEqual(tournaments, [("Championship",), ("Super Bowl",)]) + + tournaments = await base_query.values_list("name_length") + self.assertListSortEqual(tournaments, [(10,), (12,)]) diff --git a/tortoise/queryset.py b/tortoise/queryset.py index 4a80a1ccc..f11109a3e 100644 --- a/tortoise/queryset.py +++ b/tortoise/queryset.py @@ -88,13 +88,14 @@ def values( class AwaitableQuery(Generic[MODEL]): __slots__ = ( - "_joined_tables", "query", "model", + "_joined_tables", "_db", "capabilities", "_annotations", "_custom_filters", + "_q_objects", ) def __init__(self, model: Type[MODEL]) -> None: @@ -105,6 +106,7 @@ def __init__(self, model: Type[MODEL]) -> None: self.capabilities: Capabilities = model._meta.db.capabilities self._annotations: Dict[str, Expression] = {} self._custom_filters: Dict[str, FilterInfoDict] = {} + self._q_objects: List[Q] = [] def _choose_db(self, for_write: bool = False) -> BaseDBAsyncClient: """ @@ -120,13 +122,7 @@ def _choose_db(self, for_write: bool = False) -> BaseDBAsyncClient: db = router.db_for_read(self.model) return db or self.model._meta.db - def resolve_filters( - self, - model: "Type[Model]", - q_objects: List[Q], - annotations: Dict[str, Any], - custom_filters: Dict[str, FilterInfoDict], - ) -> None: + def resolve_filters(self) -> None: """ Builds the common filters for a QuerySet. @@ -135,16 +131,16 @@ def resolve_filters( :param annotations: Extra annotations to add. :param custom_filters: Pre-resolved filters to be passed through. """ - has_aggregate = self._resolve_annotate(annotations, custom_filters) + has_aggregate = self._resolve_annotate(self._annotations, self._custom_filters) modifier = QueryModifier() - for node in q_objects: + for node in self._q_objects: modifier &= node.resolve( ResolveContext( - model=model, - table=model._meta.basetable, - annotations=annotations, - custom_filters=custom_filters, + model=self.model, + table=self.model._meta.basetable, + annotations=self._annotations, + custom_filters=self._custom_filters, ) ) @@ -313,7 +309,6 @@ class QuerySet(AwaitableQuery[MODEL]): "_fields_for_select", "_filter_kwargs", "_orderings", - "_q_objects", "_distinct", "_having", "_group_bys", @@ -338,7 +333,6 @@ def __init__(self, model: Type[MODEL]) -> None: self._offset: Optional[int] = None self._filter_kwargs: Dict[str, Any] = {} self._orderings: List[Tuple[str, Any]] = [] - self._q_objects: List[Q] = [] self._distinct: bool = False self._having: Dict[str, Any] = {} self._fields_for_select: Tuple[str, ...] = () @@ -1031,12 +1025,7 @@ def _make_query(self) -> None: self.resolve_ordering( self.model, self.model._meta.basetable, self._orderings, self._annotations ) - self.resolve_filters( - model=self.model, - q_objects=self._q_objects, - annotations=self._annotations, - custom_filters=self._custom_filters, - ) + self.resolve_filters() if self._limit is not None: self.query._limit = self._limit if self._offset: @@ -1098,11 +1087,8 @@ async def _execute(self) -> List[MODEL]: class UpdateQuery(AwaitableQuery): __slots__ = ( "update_kwargs", - "q_objects", - "annotations", - "custom_filters", - "orderings", - "limit", + "_orderings", + "_limit", "values", ) @@ -1119,27 +1105,22 @@ def __init__( ) -> None: super().__init__(model) self.update_kwargs = update_kwargs - self.q_objects = q_objects - self.annotations = annotations - self.custom_filters = custom_filters + self._q_objects = q_objects + self._annotations = annotations + self._custom_filters = custom_filters self._db = db - self.limit = limit - self.orderings = orderings + self._limit = limit + self._orderings = orderings self.values: List[Any] = [] def _make_query(self) -> None: table = self.model._meta.basetable self.query = self._db.query_class.update(table) - if self.capabilities.support_update_limit_order_by and self.limit: - self.query._limit = self.limit - self.resolve_ordering(self.model, table, self.orderings, self.annotations) + if self.capabilities.support_update_limit_order_by and self._limit: + self.query._limit = self._limit + self.resolve_ordering(self.model, table, self._orderings, self._annotations) - self.resolve_filters( - model=self.model, - q_objects=self.q_objects, - annotations=self.annotations, - custom_filters=self.custom_filters, - ) + self.resolve_filters() # Need to get executor to get correct column_map executor = self._db.executor_class(model=self.model, db=self._db) count = 0 @@ -1168,8 +1149,8 @@ def _make_query(self) -> None: ResolveContext( model=self.model, table=table, - annotations=self.annotations, - custom_filters=self.custom_filters, + annotations=self._annotations, + custom_filters=self._custom_filters, ) )["field"] else: @@ -1193,11 +1174,10 @@ async def _execute(self) -> int: class DeleteQuery(AwaitableQuery): __slots__ = ( - "q_objects", - "annotations", - "custom_filters", - "orderings", - "limit", + "_annotations", + "_custom_filters", + "_orderings", + "_limit", ) def __init__( @@ -1211,29 +1191,24 @@ def __init__( orderings: List[Tuple[str, str]], ) -> None: super().__init__(model) - self.q_objects = q_objects - self.annotations = annotations - self.custom_filters = custom_filters + self._q_objects = q_objects + self._annotations = annotations + self._custom_filters = custom_filters self._db = db - self.limit = limit - self.orderings = orderings + self._limit = limit + self._orderings = orderings def _make_query(self) -> None: self.query = copy(self.model._meta.basequery) - if self.capabilities.support_update_limit_order_by and self.limit: - self.query._limit = self.limit + if self.capabilities.support_update_limit_order_by and self._limit: + self.query._limit = self._limit self.resolve_ordering( model=self.model, table=self.model._meta.basetable, - orderings=self.orderings, - annotations=self.annotations, + orderings=self._orderings, + annotations=self._annotations, ) - self.resolve_filters( - model=self.model, - q_objects=self.q_objects, - annotations=self.annotations, - custom_filters=self.custom_filters, - ) + self.resolve_filters() self.query._delete_from = True def __await__(self) -> Generator[Any, None, int]: @@ -1248,11 +1223,8 @@ async def _execute(self) -> int: class ExistsQuery(AwaitableQuery): __slots__ = ( - "q_objects", - "annotations", - "custom_filters", - "force_indexes", - "use_indexes", + "_force_indexes", + "_use_indexes", ) def __init__( @@ -1266,30 +1238,25 @@ def __init__( use_indexes: Set[str], ) -> None: super().__init__(model) - self.q_objects = q_objects - self.annotations = annotations - self.custom_filters = custom_filters + self._q_objects = q_objects self._db = db - self.force_indexes = force_indexes - self.use_indexes = use_indexes + self._annotations = annotations + self._custom_filters = custom_filters + self._force_indexes = force_indexes + self._use_indexes = use_indexes def _make_query(self) -> None: self.query = copy(self.model._meta.basequery) - self.resolve_filters( - model=self.model, - q_objects=self.q_objects, - annotations=self.annotations, - custom_filters=self.custom_filters, - ) + self.resolve_filters() self.query._limit = 1 self.query._select_other(ValueWrapper(1)) - if self.force_indexes: + if self._force_indexes: self.query._force_indexes = [] - self.query = self.query.force_index(*self.force_indexes) - if self.use_indexes: + self.query = self.query.force_index(*self._force_indexes) + if self._use_indexes: self.query._use_indexes = [] - self.query = self.query.use_index(*self.use_indexes) + self.query = self.query.use_index(*self._use_indexes) def __await__(self) -> Generator[Any, None, bool]: if self._db is None: @@ -1304,13 +1271,10 @@ async def _execute(self) -> bool: class CountQuery(AwaitableQuery): __slots__ = ( - "q_objects", - "annotations", - "custom_filters", - "limit", - "offset", - "force_indexes", - "use_indexes", + "_limit", + "_offset", + "_force_indexes", + "_use_indexes", ) def __init__( @@ -1326,34 +1290,32 @@ def __init__( use_indexes: Set[str], ) -> None: super().__init__(model) - self.q_objects = q_objects - self.annotations = annotations - self.custom_filters = custom_filters - self.limit = limit - self.offset = offset or 0 + self._q_objects = q_objects + self._annotations = annotations + self._custom_filters = custom_filters + self._limit = limit + self._offset = offset or 0 self._db = db - self.force_indexes = force_indexes - self.use_indexes = use_indexes + self._force_indexes = force_indexes + self._use_indexes = use_indexes def _make_query(self) -> None: self.query = copy(self.model._meta.basequery) - self.resolve_filters( - model=self.model, - q_objects=self.q_objects, - annotations=self.annotations, - custom_filters=self.custom_filters, - ) + self.resolve_filters() count_term = Count("*") if self.query._groupbys: count_term = count_term.over() + + # remove annotations + self.query._selects = [] self.query._select_other(count_term) - if self.force_indexes: + if self._force_indexes: self.query._force_indexes = [] - self.query = self.query.force_index(*self.force_indexes) - if self.use_indexes: + self.query = self.query.force_index(*self._force_indexes) + if self._use_indexes: self.query._use_indexes = [] - self.query = self.query.use_index(*self.use_indexes) + self.query = self.query.use_index(*self._use_indexes) def __await__(self) -> Generator[Any, None, int]: if self._db is None: @@ -1365,9 +1327,9 @@ async def _execute(self) -> int: _, result = await self._db.execute_query(str(self.query)) if not result: return 0 - count = list(dict(result[0]).values())[0] - self.offset - if self.limit and count > self.limit: - return self.limit + count = list(dict(result[0]).values())[0] - self._offset + if self._limit and count > self._limit: + return self._limit return count @@ -1376,7 +1338,7 @@ class FieldSelectQuery(AwaitableQuery): def __init__(self, model: Type[MODEL], annotations: Dict[str, Any]) -> None: super().__init__(model) - self.annotations = annotations + self._annotations = annotations def _join_table_with_forwarded_fields( self, model: Type[MODEL], table: Table, field: str, forwarded_fields: str @@ -1410,8 +1372,8 @@ def _join_table_with_forwarded_fields( def add_field_to_select_query(self, field: str, return_as: str) -> None: table = self.model._meta.basetable - if field in self.annotations: - self._annotations[return_as] = self.annotations[field] + if field in self._annotations: + self._annotations[return_as] = self._annotations[field] return if field in self.model._meta.fields_db_projection: @@ -1446,8 +1408,8 @@ def resolve_to_python_value(self, model: Type[MODEL], field: str) -> Callable: if field in (x[1] for x in model._meta.db_native_fields): return lambda x: x - if field in self.annotations: - annotation = self.annotations[field] + if field in self._annotations: + annotation = self._annotations[field] field_object = getattr(annotation, "field_object", None) if field_object: return field_object.to_python_value @@ -1483,21 +1445,18 @@ def _resolve_group_bys(self, *field_names: str): class ValuesListQuery(FieldSelectQuery, Generic[SINGLE]): __slots__ = ( - "flat", "fields", - "limit", - "offset", - "distinct", - "orderings", - "annotations", - "custom_filters", - "q_objects", - "single", - "raise_does_not_exist", - "fields_for_select_list", - "group_bys", - "force_indexes", - "use_indexes", + "_limit", + "_offset", + "_distinct", + "_orderings", + "_single", + "_raise_does_not_exist", + "_fields_for_select_list", + "_flat", + "_group_bys", + "_force_indexes", + "_use_indexes", ) def __init__( @@ -1525,20 +1484,20 @@ def __init__( fields_for_select = {str(i): field for i, field in enumerate(fields_for_select_list)} self.fields = fields_for_select - self.limit = limit - self.offset = offset - self.distinct = distinct - self.orderings = orderings - self.custom_filters = custom_filters - self.q_objects = q_objects - self.single = single - self.raise_does_not_exist = raise_does_not_exist - self.fields_for_select_list = fields_for_select_list - self.flat = flat + self._limit = limit + self._offset = offset + self._distinct = distinct + self._orderings = orderings + self._custom_filters = custom_filters + self._q_objects = q_objects + self._single = single + self._raise_does_not_exist = raise_does_not_exist + self._fields_for_select_list = fields_for_select_list + self._flat = flat self._db = db - self.group_bys = group_bys - self.force_indexes = force_indexes - self.use_indexes = use_indexes + self._group_bys = group_bys + self._force_indexes = force_indexes + self._use_indexes = use_indexes def _make_query(self) -> None: self._joined_tables = [] @@ -1550,30 +1509,25 @@ def _make_query(self) -> None: self.resolve_ordering( model=self.model, table=self.model._meta.basetable, - orderings=self.orderings, - annotations=self.annotations, - ) - self.resolve_filters( - model=self.model, - q_objects=self.q_objects, - annotations=self.annotations, - custom_filters=self.custom_filters, + orderings=self._orderings, + annotations=self._annotations, ) - if self.limit: - self.query._limit = self.limit - if self.offset: - self.query._offset = self.offset - if self.distinct: + self.resolve_filters() + if self._limit: + self.query._limit = self._limit + if self._offset: + self.query._offset = self._offset + if self._distinct: self.query._distinct = True - if self.group_bys: - self.query._groupbys = self._resolve_group_bys(*self.group_bys) + if self._group_bys: + self.query._groupbys = self._resolve_group_bys(*self._group_bys) - if self.force_indexes: + if self._force_indexes: self.query._force_indexes = [] - self.query = self.query.force_index(*self.force_indexes) - if self.use_indexes: + self.query = self.query.force_index(*self._force_indexes) + if self._use_indexes: self.query._use_indexes = [] - self.query = self.query.use_index(*self.use_indexes) + self.query = self.query.use_index(*self._use_indexes) @overload def __await__( @@ -1601,7 +1555,7 @@ async def _execute(self) -> Union[List[Any], Tuple]: (key, self.resolve_to_python_value(self.model, name)) for key, name in self.fields.items() ] - if self.flat: + if self._flat: func = columns[0][1] flatmap = lambda entry: func(entry["0"]) # noqa lst_values = list(map(flatmap, result)) @@ -1609,11 +1563,11 @@ async def _execute(self) -> Union[List[Any], Tuple]: listmap = lambda entry: tuple(func(entry[column]) for column, func in columns) # noqa lst_values = list(map(listmap, result)) - if self.single: + if self._single: if len(lst_values) == 1: return lst_values[0] if not lst_values: - if self.raise_does_not_exist: + if self._raise_does_not_exist: raise DoesNotExist(self.model) return None # type: ignore raise MultipleObjectsReturned(self.model) @@ -1622,19 +1576,16 @@ async def _execute(self) -> Union[List[Any], Tuple]: class ValuesQuery(FieldSelectQuery, Generic[SINGLE]): __slots__ = ( - "fields_for_select", - "limit", - "offset", - "distinct", - "orderings", - "annotations", - "custom_filters", - "q_objects", - "single", - "raise_does_not_exist", - "group_bys", - "force_indexes", - "use_indexes", + "_fields_for_select", + "_limit", + "_offset", + "_distinct", + "_orderings", + "_single", + "_raise_does_not_exist", + "_group_bys", + "_force_indexes", + "_use_indexes", ) def __init__( @@ -1656,54 +1607,55 @@ def __init__( use_indexes: Set[str], ) -> None: super().__init__(model, annotations) - self.fields_for_select = fields_for_select - self.limit = limit - self.offset = offset - self.distinct = distinct - self.orderings = orderings - self.custom_filters = custom_filters - self.q_objects = q_objects - self.single = single - self.raise_does_not_exist = raise_does_not_exist + self._fields_for_select = fields_for_select + self._limit = limit + self._offset = offset + self._distinct = distinct + self._orderings = orderings + self._custom_filters = custom_filters + self._q_objects = q_objects + self._single = single + self._raise_does_not_exist = raise_does_not_exist self._db = db - self.group_bys = group_bys - self.force_indexes = force_indexes - self.use_indexes = use_indexes + self._group_bys = group_bys + self._force_indexes = force_indexes + self._use_indexes = use_indexes def _make_query(self) -> None: self._joined_tables = [] self.query = copy(self.model._meta.basequery) - for return_as, field in self.fields_for_select.items(): + for return_as, field in self._fields_for_select.items(): self.add_field_to_select_query(field, return_as) self.resolve_ordering( model=self.model, table=self.model._meta.basetable, - orderings=self.orderings, - annotations=self.annotations, - ) - self.resolve_filters( - model=self.model, - q_objects=self.q_objects, - annotations=self.annotations, - custom_filters=self.custom_filters, + orderings=self._orderings, + annotations=self._annotations, ) - if self.limit: - self.query._limit = self.limit - if self.offset: - self.query._offset = self.offset - if self.distinct: + self.resolve_filters() + + # remove annotations that are not in fields_for_select + self.query._selects = [ + select for select in self.query._selects if select.alias in self._fields_for_select + ] + + if self._limit: + self.query._limit = self._limit + if self._offset: + self.query._offset = self._offset + if self._distinct: self.query._distinct = True - if self.group_bys: - self.query._groupbys = self._resolve_group_bys(*self.group_bys) + if self._group_bys: + self.query._groupbys = self._resolve_group_bys(*self._group_bys) - if self.force_indexes: + if self._force_indexes: self.query._force_indexes = [] - self.query = self.query.force_index(*self.force_indexes) - if self.use_indexes: + self.query = self.query.force_index(*self._force_indexes) + if self._use_indexes: self.query._use_indexes = [] - self.query = self.query.use_index(*self.use_indexes) + self.query = self.query.use_index(*self._use_indexes) @overload def __await__( @@ -1733,7 +1685,7 @@ async def _execute(self) -> Union[List[dict], Dict]: val for val in [ (alias, self.resolve_to_python_value(self.model, field_name)) - for alias, field_name in self.fields_for_select.items() + for alias, field_name in self._fields_for_select.items() ] if not isinstance(val[1], types.LambdaType) ] @@ -1743,11 +1695,11 @@ async def _execute(self) -> Union[List[dict], Dict]: for col, func in columns: row[col] = func(row[col]) - if self.single: + if self._single: if len(result) == 1: return result[0] if not result: - if self.raise_does_not_exist: + if self._raise_does_not_exist: raise DoesNotExist(self.model) return None # type: ignore raise MultipleObjectsReturned(self.model) @@ -1780,7 +1732,7 @@ def __await__(self) -> Generator[Any, None, List[MODEL]]: class BulkUpdateQuery(UpdateQuery, Generic[MODEL]): - __slots__ = ("objects", "fields", "batch_size", "queries") + __slots__ = ("fields", "_objects", "_batch_size", "_queries") def __init__( self, @@ -1805,34 +1757,29 @@ def __init__( limit=limit, orderings=orderings, ) - self.objects = objects self.fields = fields - self.batch_size = batch_size - self.queries: List[QueryBuilder] = [] + self._objects = objects + self._batch_size = batch_size + self._queries: List[QueryBuilder] = [] def _make_query(self) -> None: table = self.model._meta.basetable self.query = self._db.query_class.update(table) - if self.capabilities.support_update_limit_order_by and self.limit: - self.query._limit = self.limit + if self.capabilities.support_update_limit_order_by and self._limit: + self.query._limit = self._limit self.resolve_ordering( model=self.model, table=table, - orderings=self.orderings, - annotations=self.annotations, + orderings=self._orderings, + annotations=self._annotations, ) - self.resolve_filters( - model=self.model, - q_objects=self.q_objects, - annotations=self.annotations, - custom_filters=self.custom_filters, - ) + self.resolve_filters() executor = self._db.executor_class(model=self.model, db=self._db) pk_attr = self.model._meta.pk_attr source_pk_attr = self.model._meta.fields_map[pk_attr].source_field or pk_attr pk = Field(source_pk_attr) - for objects_item in chunk(self.objects, self.batch_size): + for objects_item in chunk(self._objects, self._batch_size): query = copy(self.query) for field in self.fields: case = Case() @@ -1856,30 +1803,30 @@ def _make_query(self) -> None: pk_list.append(value) query = query.set(field, case) query = query.where(pk.isin(pk_list)) - self.queries.append(query) + self._queries.append(query) async def _execute(self) -> int: count = 0 - for query in self.queries: + for query in self._queries: count += (await self._db.execute_query(str(query)))[0] return count def sql(self, **kwargs) -> str: self.as_query() - return ";".join([str(query) for query in self.queries]) + return ";".join([str(query) for query in self._queries]) class BulkCreateQuery(AwaitableQuery, Generic[MODEL]): __slots__ = ( - "objects", - "ignore_conflicts", - "batch_size", + "_objects", + "_ignore_conflicts", + "_batch_size", "_db", - "executor", - "insert_query", - "insert_query_all", - "update_fields", - "on_conflict", + "_executor", + "_insert_query", + "_insert_query_all", + "_update_fields", + "_on_conflict", ) def __init__( @@ -1893,70 +1840,68 @@ def __init__( on_conflict: Optional[Iterable[str]] = None, ): super().__init__(model) - self.objects = objects - self.ignore_conflicts = ignore_conflicts - self.batch_size = batch_size + self._objects = objects + self._ignore_conflicts = ignore_conflicts + self._batch_size = batch_size self._db = db - self.update_fields = update_fields - self.on_conflict = on_conflict + self._update_fields = update_fields + self._on_conflict = on_conflict def _make_query(self) -> None: - self.executor = self._db.executor_class(model=self.model, db=self._db) - if self.ignore_conflicts or self.update_fields: - regular_columns, columns = self.executor._prepare_insert_columns() - self.insert_query = self.executor._prepare_insert_statement( - columns, ignore_conflicts=self.ignore_conflicts + self._executor = self._db.executor_class(model=self.model, db=self._db) + if self._ignore_conflicts or self._update_fields: + _, columns = self._executor._prepare_insert_columns() + self._insert_query = self._executor._prepare_insert_statement( + columns, ignore_conflicts=self._ignore_conflicts ) - self.insert_query_all = self.insert_query + self._insert_query_all = self._insert_query if self.model._meta.generated_db_fields: - regular_columns_all, columns_all = self.executor._prepare_insert_columns( - include_generated=True - ) - self.insert_query_all = self.executor._prepare_insert_statement( + _, columns_all = self._executor._prepare_insert_columns(include_generated=True) + self._insert_query_all = self._executor._prepare_insert_statement( columns_all, has_generated=False, - ignore_conflicts=self.ignore_conflicts, + ignore_conflicts=self._ignore_conflicts, ) - if self.update_fields: + if self._update_fields: alias = f"new_{self.model._meta.db_table}" - self.insert_query_all = self.insert_query_all.as_(alias).on_conflict( - *self.on_conflict + self._insert_query_all = self._insert_query_all.as_(alias).on_conflict( + *self._on_conflict ) - self.insert_query = self.insert_query.as_(alias).on_conflict(*self.on_conflict) - for update_field in self.update_fields: - self.insert_query_all = self.insert_query_all.do_update(update_field) - self.insert_query = self.insert_query.do_update(update_field) + self._insert_query = self._insert_query.as_(alias).on_conflict(*self._on_conflict) + for update_field in self._update_fields: + self._insert_query_all = self._insert_query_all.do_update(update_field) + self._insert_query = self._insert_query.do_update(update_field) else: - self.insert_query_all = self.executor.insert_query_all - self.insert_query = self.executor.insert_query + self._insert_query_all = self._executor.insert_query_all + self._insert_query = self._executor.insert_query async def _execute(self) -> None: - for instance_chunk in chunk(self.objects, self.batch_size): + for instance_chunk in chunk(self._objects, self._batch_size): values_lists_all = [] values_lists = [] for instance in instance_chunk: if instance._custom_generated_pk: values_lists_all.append( [ - self.executor.column_map[field_name]( + self._executor.column_map[field_name]( getattr(instance, field_name), instance ) - for field_name in self.executor.regular_columns_all + for field_name in self._executor.regular_columns_all ] ) else: values_lists.append( [ - self.executor.column_map[field_name]( + self._executor.column_map[field_name]( getattr(instance, field_name), instance ) - for field_name in self.executor.regular_columns + for field_name in self._executor.regular_columns ] ) if values_lists_all: - await self._db.execute_many(str(self.insert_query_all), values_lists_all) + await self._db.execute_many(str(self._insert_query_all), values_lists_all) if values_lists: - await self._db.execute_many(str(self.insert_query), values_lists) + await self._db.execute_many(str(self._insert_query), values_lists) def __await__(self) -> Generator[Any, None, None]: if self._db is None: @@ -1966,6 +1911,6 @@ def __await__(self) -> Generator[Any, None, None]: def sql(self, **kwargs) -> str: self.as_query() - if self.insert_query and self.insert_query_all: - return ";".join([str(self.insert_query), str(self.insert_query_all)]) - return str(self.insert_query or self.insert_query_all) + if self._insert_query and self._insert_query_all: + return ";".join([str(self._insert_query), str(self._insert_query_all)]) + return str(self._insert_query or self._insert_query_all)