Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 17 additions & 3 deletions docs/usage/crud.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,20 +25,34 @@ print(f"用户ID: {user.id}") # 可以立即获取主键
### 批量创建

```python
# 批量创建
# 使用 Pydantic 模型批量创建
users_data = [
UserCreate(name="用户1", email="user1@example.com"),
UserCreate(name="用户2", email="user2@example.com"),
UserCreate(name="用户3", email="user3@example.com")
]
users = await user_crud.create_models(session, users_data)

# 使用字典批量创建(高性能方式)
# 使用字典批量创建
users_dict = [
{"name": "用户4", "email": "user4@example.com"},
{"name": "用户5", "email": "user5@example.com"}
]
users = await user_crud.bulk_create_models(session, users_dict)
users = await user_crud.create_models(session, users_dict)

# 混合使用 Pydantic 模型和字典
users_mixed = [
UserCreate(name="用户6", email="user6@example.com"),
{"name": "用户7", "email": "user7@example.com"},
]
users = await user_crud.create_models(session, users_mixed)

# 使用字典批量创建(高性能方式)
users_bulk = [
{"name": "用户8", "email": "user8@example.com"},
{"name": "用户9", "email": "user9@example.com"}
]
users = await user_crud.bulk_create_models(session, users_bulk)
```

## 查询操作
Expand Down
16 changes: 8 additions & 8 deletions sqlalchemy_crud_plus/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def _get_pk_filter(self, pk: Any | list[Any]) -> list[ColumnExpressionArgument[b
async def create_model(
self,
session: AsyncSession,
obj: CreateSchema,
obj: CreateSchema | dict[str, Any],
flush: bool = False,
commit: bool = False,
**kwargs,
Expand All @@ -79,13 +79,13 @@ async def create_model(
Create a new instance of a model.

:param session: The SQLAlchemy async session
:param obj: The Pydantic schema containing data to be saved
:param obj: A Pydantic schema or dictionary containing the data to be saved
:param flush: If `True`, flush all object changes to the database
:param commit: If `True`, commits the transaction immediately
:param kwargs: Additional model data not included in the pydantic schema
:param kwargs: Additional model data not included in the pydantic schema or dict
:return:
"""
obj_data = obj.model_dump()
obj_data = obj if isinstance(obj, dict) else obj.model_dump()
if kwargs:
obj_data.update(kwargs)

Expand All @@ -102,7 +102,7 @@ async def create_model(
async def create_models(
self,
session: AsyncSession,
objs: list[CreateSchema],
objs: list[CreateSchema | dict[str, Any]],
flush: bool = False,
commit: bool = False,
**kwargs,
Expand All @@ -111,15 +111,15 @@ async def create_models(
Create new instances of a model.

:param session: The SQLAlchemy async session
:param objs: The Pydantic schema list containing data to be saved
:param objs: A list of Pydantic schemas or dictionaries containing the data to be saved
:param flush: If `True`, flush all object changes to the database
:param commit: If `True`, commits the transaction immediately
:param kwargs: Additional model data not included in the pydantic schema
:param kwargs: Additional model data not included in the pydantic schema or dict
:return:
"""
ins_list = []
for obj in objs:
obj_data = obj.model_dump()
obj_data = obj if isinstance(obj, dict) else obj.model_dump()
if kwargs:
obj_data.update(kwargs)
ins = self.model(**obj_data)
Expand Down
96 changes: 96 additions & 0 deletions tests/test_create.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,3 +159,99 @@ async def test_bulk_create_models_with_commit(db: AsyncSession, crud_ins: CRUDPl
assert len(results) == 2
assert results[0].name == 'bulk_commit_1'
assert results[1].name == 'bulk_commit_2'


@pytest.mark.asyncio
async def test_create_model_with_dict(db: AsyncSession, crud_ins: CRUDPlus[Ins]):
async with db.begin():
data = {'name': 'dict_item'}
result = await crud_ins.create_model(db, data)

assert result.name == 'dict_item'
assert result.id is not None


@pytest.mark.asyncio
async def test_create_model_with_dict_and_flush(db: AsyncSession, crud_ins: CRUDPlus[Ins]):
async with db.begin():
data = {'name': 'dict_flush_item'}
result = await crud_ins.create_model(db, data, flush=True)

assert result.name == 'dict_flush_item'
assert result.id is not None


@pytest.mark.asyncio
async def test_create_model_with_dict_and_commit(db: AsyncSession, crud_ins: CRUDPlus[Ins]):
data = {'name': 'dict_commit_item'}
result = await crud_ins.create_model(db, data, commit=True)

assert result.name == 'dict_commit_item'
assert result.id is not None


@pytest.mark.asyncio
async def test_create_model_with_dict_and_kwargs(db: AsyncSession, crud_ins: CRUDPlus[Ins]):
async with db.begin():
data = {'name': 'dict_kwargs_item'}
result = await crud_ins.create_model(db, data, is_deleted=True)

assert result.name == 'dict_kwargs_item'
assert result.is_deleted is True


@pytest.mark.asyncio
async def test_create_models_with_dict(db: AsyncSession, crud_ins: CRUDPlus[Ins]):
async with db.begin():
data = [{'name': f'dict_batch_{i}'} for i in range(3)]
results = await crud_ins.create_models(db, data)

assert len(results) == 3
assert all(r.name.startswith('dict_batch_') for r in results)
assert all(r.id is not None for r in results)


@pytest.mark.asyncio
async def test_create_models_with_dict_and_flush(db: AsyncSession, crud_ins: CRUDPlus[Ins]):
async with db.begin():
data = [{'name': f'dict_flush_batch_{i}'} for i in range(2)]
results = await crud_ins.create_models(db, data, flush=True)

assert len(results) == 2
assert all(r.id is not None for r in results)


@pytest.mark.asyncio
async def test_create_models_with_dict_and_commit(db: AsyncSession, crud_ins: CRUDPlus[Ins]):
data = [{'name': f'dict_commit_batch_{i}'} for i in range(2)]
results = await crud_ins.create_models(db, data, commit=True)

assert len(results) == 2
assert all(r.id is not None for r in results)


@pytest.mark.asyncio
async def test_create_models_with_dict_and_kwargs(db: AsyncSession, crud_ins: CRUDPlus[Ins]):
async with db.begin():
data = [{'name': f'dict_kwargs_batch_{i}'} for i in range(2)]
results = await crud_ins.create_models(db, data, is_deleted=True)

assert len(results) == 2
assert all(r.is_deleted is True for r in results)


@pytest.mark.asyncio
async def test_create_models_with_mixed_input(db: AsyncSession, crud_ins: CRUDPlus[Ins]):
async with db.begin():
data = [
CreateIns(name='schema_item'),
{'name': 'dict_item'},
CreateIns(name='schema_item_2'),
]
results = await crud_ins.create_models(db, data)

assert len(results) == 3
assert results[0].name == 'schema_item'
assert results[1].name == 'dict_item'
assert results[2].name == 'schema_item_2'
assert all(r.id is not None for r in results)