diff --git a/.gitignore b/.gitignore index fae75a3..1547f37 100644 --- a/.gitignore +++ b/.gitignore @@ -8,3 +8,6 @@ build dist pydbml.egg-info .mypy_cache +.coverage +.eggs +.idea diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..3f820d8 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,168 @@ +# 1.2.0 +- Fix: Temporarily disable unicode characters support in identifiers for performance (#59) + +# 1.1.4 +- Fix: Remove trailing comma in Enum SQL (#58 thanks @ralfschulze for reporting) + +# 1.1.3 +- New: support more index types. Thanks @pierresouchay for the contribution. + +# 1.1.2 +- Fix: escaping single quotes in a column's default value. Thanks @ryanproback for the contribution +- Fix: TableGroup and Project name are now safely quoted on render. Thanks @ryanproback for reporting +- Fix: line breaks in column and index options are now allowed. Thanks @aardjon for reporting +- Fix: table elements order is now not enforced by the parser. Thanks @aardjon for reporting +- New: TableGroup now can have notes (DBML v.3.7.2) +- New: TableGroup now can have color (DBML v.3.7.4) + +# 1.1.1 +- New: SQL and DBML renderers can now be supplied to parser + +# 1.1.0 + +- New: SQL and DBML rendering rewritten tow support external renderers +- New: allow unicode characters in identifiers (DBML v3.3.0) +- New: support for arbitrary table and column properties (#37) + +# 1.0.11 +- Fix: allow pk in named indexes (thanks @pierresouchay for the contribution) + +# 1.0.10 +- New: Sticky notes syntax (DBML v3.2.0) +- Fix: Table header color was not rendered in `dbml()` (thanks @tristangrebot for the contribution) +- New: allow array column types (DBML v3.1.0) +- New: allow double quotes in expressions (DBML v3.1.2) +- Fix: recursion in object equality check +- New: don't allow duplicate refs even if they have different inline method (DBML v3.1.6) + +# 1.0.9 + +- Fix: enum collision from different schemas. Thanks @ewdurbin for the contribution + +# 1.0.8 + +- Fix: (#27) allowing comments after Tables, Enums, etc. Thanks @marktaff for reporting + +# 1.0.7 + +- Fix: removing indentation bug + +# 1.0.6 + +- Fix: (#26) bug in note empty line stripping, thanks @Jaschenn for reporting +- New: get_references_for_sql table method + +# 1.0.5 + +- Fix: junction table now has the schema of the first referenced table (as introduced in DBML 2.4.3) +- Fix: typing issue which failed for Python 3.8 and Python 3.9 + +# 1.0.4 + +- New: referenced tables in SQL are now defined first in SQL (#23 reported by @minhl) +- Fix: single quotes were not escaped in column notes (#24 reported by @fivegrant) + +# 1.0.3 + +- Fix: inline many-to-many references were not rendered in sql + +# 1.0.2 + +- New: "backslash newline" is supported in note text (line continuation) +- New: notes have reference to their parent. Note.sql now depends on type of parent (for tables and columns it's COMMENT ON clause) +- New: pydbml no longer splits long notes into multiple lines +- Fix: inline ref schema bug, thanks to @jens-koster +- Fix: (#16) notes were not idempotent, thanks @jens-koster for reporting +- Fix: (#15) note objects were not supported in project definition, thanks @jens-koster for reporting +- Fix: (#20) schema didn't work in table group definition, thanks @mjfii for reporting +- Fix: quotes in note text broke sql and dbml +- New: proper support of composite primary keys without creating an index +- New: support of many-to-many relationships + +# 1.0.1 + +- Fixed setup.py, thanks to @vosskj03. + +# 1.0.0 + +- New project architecture, full support for creating and editing DBML. See details in [Upgrading to PyDBML 1.0.0](docs/upgrading.md) +- New Expression class +- Support DBML 2.4.1 syntax: + - Multiline comments + - Multiple schemas + +# 0.4.2 + +- Fix: after editing column name index dbml was not updated. +- Fix: enums with spaces in name were not applied. +- Fix: after editing column name table dict was not updated. +- Fix: after editing enum column type was not updated. +- Removed EnumType class. Only Enum is used now. + +# 0.4.1 + +- Reworked `__repr__` and `__str__` methods on all classes. They are now much simplier and more readable. +- Comments on classes are now rendered as SQL comments in `sql` property (previously notes were rendered as comments on some classes). +- Notes on `Table` and `Column` classes are rendered as SQL comments in `sql` property: `COMMENT ON TABLE "x" is 'y'`. +- New: `dbml` property on most classes and on parsed results which returns the DBML code. +- Fix: sql for Reference and TableReference. + +# 0.4.0 + +- New: Support composite references. **Breaks backward compatibility!** `col1`, `col2` attributes on `Reference` and `col`, `ref_col` attributes on `TableReference` are now lists of `Column` instead of `Column`. +- `TableGroup` now holds references to actual tables. + +# 0.3.5 + +- New: Support references by aliases. +- New: Support indexes with expressions. +- New: You can now compare SQLObjects of same class. +- New: Add check for duplicate references on a table. +- Fix: minor bug fixes. + +# 0.3.4 + +- Notes are now added as comments in SQL for tables, table columns, indeces, enums. + +# 0.3.3 + +- Fix: bug in TableReference +- Fix: if schema had newline or comment at the end, it crashed parser + +# 0.3.2 + +- Fix TableReference sql + +# 0.3.1 + +- Fix: files in **UTF-8 with BOM** encoding couldn't be parsed. + +# 0.3 + +- More tests and more bug fixes. +- Added index columns validation. +- Added table group items validation. +- References now contain link to Table and Column objects instead of just names. +- Indexes now contain link to Column objects instead of just names. + +# 0.2 + +- Better syntax errors. +- sql for each object now contains in `sql` property instead of string rerpresentation. Added proper string representations. +- Added syntax tests. +- Million bugs fixed after testing. + +# 0.1.1 + +- Comments are now parsed too if they are before [b] or on the same line [l] as the entity. Works for: tables[b], columns[lb], references [lb], indexes[lb], enum items [lb], enums [b], project [b] and table group [b] +- All class instances will now have an empty Note in `note` attribute instead of None. +- Add string representation for Note and EnumItem. +- Enum instance now acts like list of EnumItems. +- Add EnumType to use in column.type attribute. +- Column type is now replaced by EnumType instance if enum with such name is defined. +- Remove unnecessary ColumnType class. +- Fix: note definition, project definition, some other definitions + +# 0.1 + +- Initial release diff --git a/LICENSE b/LICENSE index 1637db2..c13f991 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,6 @@ MIT License -Copyright (c) 2021 +Copyright (c) 2024 Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/README.md b/README.md index 56771a7..d30ecb6 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,21 @@ -[![](https://img.shields.io/pypi/v/pydbml.svg)](https://pypi.org/project/pydbml/) [![](https://img.shields.io/github/v/tag/Vanderhoof/PyDBML.svg?label=GitHub)](https://github.com/Vanderhoof/PyDBML) +[![](https://img.shields.io/pypi/v/pydbml.svg)](https://pypi.org/project/pydbml/) [![](https://img.shields.io/pypi/dm/pydbml.svg)](https://pypi.org/project/pydbml/) [![](https://img.shields.io/github/v/tag/Vanderhoof/PyDBML.svg?label=GitHub)](https://github.com/Vanderhoof/PyDBML) ![](coverage.svg) # DBML parser for Python -*Compliant with DBML **v2.3.0** syntax* +*Compliant with DBML **v3.9.5** syntax* -PyDBML is a Python parser for [DBML](https://www.dbml.org) syntax. +PyDBML is a Python parser and builder for [DBML](https://www.dbml.org) syntax. + +> The project was rewritten in May 2022, the new version 1.0.0 is not compatible with versions 0.x.x. See details in [Upgrading to PyDBML 1.0.0](docs/upgrading.md). + +**Docs:** + +* [Class Reference](docs/classes.md) +* [Creating DBML schema](docs/creating_schema.md) +* [Upgrading to PyDBML 1.0.0](docs/upgrading.md) +* [Arbitrary Properties](docs/properties.md) + +> PyDBML requires Python v3.8 or higher ## Installation @@ -16,7 +27,7 @@ pip3 install pydbml ## Quick start -Import the `PyDBML` class and initialize it with path to DBML-file: +To parse a DBML file, import the `PyDBML` class and initialize it with Path object ```python >>> from pydbml import PyDBML @@ -25,21 +36,27 @@ Import the `PyDBML` class and initialize it with path to DBML-file: ``` -or with file stream: +or with file stream + ```python >>> with open('test_schema.dbml') as f: ... parsed = PyDBML(f) ``` -or with entire source string: +or with entire source string + ```python >>> with open('test_schema.dbml') as f: ... source = f.read() >>> parsed = PyDBML(source) +>>> parsed + ``` +The parser returns a Database object that is a container for the parsed DBML entities. + You can access tables inside the `tables` attribute: ```python @@ -55,24 +72,24 @@ countries ``` -Or just by getting items by index or table name: +Or just by getting items by index or full table name: ```python ->>> parsed['countries'] - >>> parsed[1] -
+
+>>> parsed['public.countries'] +
``` -Other meaningful attributes are: +Other attributes are: * **refs** — list of all references, * **enums** — list of all enums, * **table_groups** — list of all table groups, * **project** — the Project object, if was defined. -Finally, you can get the SQL for your DBML schema by accessing `sql` property: +Generate SQL for your DBML Database by accessing the `sql` property: ```python >>> print(parsed.sql) # doctest:+ELLIPSIS @@ -82,215 +99,58 @@ CREATE TYPE "orders_status" AS ENUM ( 'done', 'failure', ); + CREATE TYPE "product status" AS ENUM ( 'Out of Stock', 'In Stock', ); + CREATE TABLE "orders" ( "id" int PRIMARY KEY AUTOINCREMENT, "user_id" int UNIQUE NOT NULL, - "status" orders_status, + "status" "orders_status", "created_at" varchar ); ... ``` -# Docs - -## Table class - -After running parser all tables from the schema are stored in `tables` attribute of the `PyDBMLParseResults` object. - -```python ->>> from pydbml import PyDBML ->>> parsed = PyDBML.parse_file('test_schema.dbml') ->>> table = parsed.tables[0] ->>> table -
- -``` - -Important attributes of the `Table` object are: - -* **name** (str) — table name, -* **refs** (list of `TableReference`) — all foreign keys, defined for the table, -* **columns** (list of `Column`) — table columns, -* **indexes** (list of `Index`) — indexes, defined for the table. -* **alias** (str) — table alias, if defined. -* **note** (str) — note for table, if defined. -* **header_color** (str) — the header_color param, if defined. -* **comment** (str) — comment, if it was added just before table definition. - -`Table` object may act as a list or a dictionary of columns: - -```python ->>> table[0] - ->>> table['status'] - - -``` - -## Column class - -Table columns are stored in the `columns` attribute of a `Table` object. - -Important attributes of the `Column` object are: - -* **name** (str) — column name, -* **table** (Table)— link to `Table` object, which holds this column. -* **type** (str or `Enum`) — column type. If type is a enum, defined in the same schema, this attribute will hold a link to corresponding `Enum` object. -* **unique** (bool) — is column unique. -* **not_null** (bool) — is column not null. -* **pk** (bool) — is column a primary key. -* **autoinc** (bool) — is an autoincrement column. -* **default** (str or int or float) — column's default value. -* **note** (Note) — column's note if was defined. -* **comment** (str) — comment, if it was added just before column definition or right after it on the same line. - -## Index class - -Indexes are stored in the `indexes` attribute of a `Table` object. - -Important attributes of the `Index` object are: - -* **subjects** (list of `Column` or `str`) — list subjects which are indexed. Columns are represented by `Column` objects, expressions (`getdate()`) are stored as strings `(getdate())`. Expressions are supported since **0.3.5**. -* **table** (`Table`) — table, for which this index is defined. -* **name** (str) — index name, if defined. -* **unique** (bool) — is index unique. -* **type** (str) — index type, if defined. Can be either `hash` or `btree`. -* **pk** (bool) — is this a primary key index. -* **note** (note) — index note, if defined. -* **comment** (str) — comment, if it was added just before index definition. - -## Reference class - -After running parser all references from the schema are stored in `refs` attribute of the `PyDBMLParseResults` object. - -```python ->>> from pydbml import PyDBML ->>> parsed = PyDBML.parse_file('test_schema.dbml') ->>> parsed.refs[0] - - -``` - -Important attributes of the `Reference` object are: - -* **type** (str) — reference type, in DBML syntax: - * `<` — one to many; - * `>` — many to one; - * `-` — one to one. -* **table1** (`Table`) — link to the first table of the reference. -* **col1** (list os `Column`) — list of Column objects of the first table of the reference. Changed in **0.4.0**, previously was plain `Column`. -* **table2** (`Table`) — link to the second table of the reference. -* **col2** (list of `Column`) — list of Column objects of the second table of the reference. Changed in **0.4.0**, previously was plain `Column`. -* **name** (str) — reference name, if defined. -* **on_update** (str) — reference's on update setting, if defined. -* **on_delete** (str) — reference's on delete setting, if defined. -* **comment** (str) — comment, if it was added before reference definition. - -## TableReference class - -Apart from `Reference` objects, parser also creates `TableReference` objects, which are stored in each table, where the foreign key should be defined. These objects don't have types. List of references is stored in `refs` attribute of a Table object: - -```python ->>> from pydbml import PyDBML ->>> parsed = PyDBML.parse_file('test_schema.dbml') ->>> order_items_refs = parsed.tables[1].refs ->>> order_items_refs[0] - - -``` - -Important attributes of the `TableReference` object are: - -* **col** (list[`Column`]) — list of Column objects, which are referenced in this table. Changed in **0.4.0**, previously was plain `Column`. -* **ref_table** (`Table`) — link to the second table of the reference. -* **ref_col** (list[`Column`]) — list of Column objects, which are referenced by this table. Changed in **0.4.0**, previously was plain `Column`. -* **name** (str) — reference name, if defined. -* **on_update** (str) — reference's on update setting, if defined. -* **on_delete** (str) — reference's on delete setting, if defined. - -## Enum class - -After running parser all enums from the schema are stored in `enums` attribute of the `PyDBMLParseResults` object. - -```python ->>> from pydbml import PyDBML ->>> parsed = PyDBML.parse_file('test_schema.dbml') ->>> enum = parsed.enums[0] ->>> enum - - -``` - -`Enum` object contains three attributes: - -* **name** (str) — enum name, -* **items** (list of `EnumItem`) — list of items. -* **comment** (str) — comment, which was defined before enum definition. - -Enum objects also act as a list of items: - -```python ->>> enum[0] - - -``` - -### EnumItem class - -Enum items are stored in the `items` property of a `Enum` class. - -`EnumItem` object contains following attributes: - -* **name** (str) — enum item name, -* **note** (`Note`) — enum item note, if was defined. -* **comment** (str) — comment, which was defined before enum item definition or right after it on the same line. - -## Note class - -Note is a basic class, which may appear in some other classes' `note` attribute. It has just one meaningful attribute: - -**text** (str) — note text. - -## Project class - -After running parser the project info is stored in the `project` attribute of the `PyDBMLParseResults` object. - -```python ->>> from pydbml import PyDBML ->>> parsed = PyDBML.parse_file('test_schema.dbml') ->>> parsed.project - - -``` - -Attributes of the `Project` object: - -* **name** (str) — project name, -* **items** (str) — dictionary with project items, -* **note** (`Note`) — note, if was defined, -* **comment** (str) — comment, if was added before project definition. - -## TableGroup class - -After running parser the project info is stored in the `project` attribute of the `PyDBMLParseResults` object. +Generate DBML for your Database by accessing the `dbml` property: ```python ->>> from pydbml import PyDBML ->>> parsed = PyDBML.parse_file('test_schema.dbml') ->>> parsed.table_groups -[, ] +>>> parsed.project.items['author'] = 'John Doe' +>>> print(parsed.dbml) # doctest:+ELLIPSIS +Project "test_schema" { + author: 'John Doe' + Note { + 'This schema is used for PyDBML doctest' + } +} + +Enum "orders_status" { + "created" + "running" + "done" + "failure" +} + +Enum "product status" { + "Out of Stock" + "In Stock" +} + +Table "orders" [headercolor: #fff] { + "id" int [pk, increment] + "user_id" int [unique, not null] + "status" "orders_status" + "created_at" varchar +} + +Table "order_items" { + "order_id" int + "product_id" int + "quantity" int [default: 1] +} +... ``` - -Attributes of the `TableGroup` object: - -* **name** (str) — table group name, -* **items** (str) — dictionary with tables in the group, -* **comment** (str) — comment, if was added before table group definition. - -> TableGroup `items` parameter initially holds just the names of the tables, but after parsing the whole document, `PyDBMLParseResults` class replaces them with references to actual tables. diff --git a/TODO.md b/TODO.md deleted file mode 100644 index 40aaea2..0000000 --- a/TODO.md +++ /dev/null @@ -1,3 +0,0 @@ -- Better __repr__, now they are too long and unreadable, -- Test all __repr__ and __str__, -- Add testcoverage. \ No newline at end of file diff --git a/changelog.md b/changelog.md deleted file mode 100644 index d297cf8..0000000 --- a/changelog.md +++ /dev/null @@ -1,59 +0,0 @@ -# 0.4.0 - -- New: Support composite references. **Breaks backward compatibility!** `col1`, `col2` attributes on `Reference` and `col`, `ref_col` attributes on `TableReference` are now lists of `Column` instead of `Column`. -- `TableGroup` now holds references to actual tables. - -# 0.3.5 - -- New: Support references by aliases. -- New: Support indexes with expressions. -- New: You can now compare SQLObjects of same class. -- New: Add check for duplicate references on a table. -- Fix: minor bug fixes. - -# 0.3.4 - -- Notes are now added as comments in SQL for tables, table columns, indeces, enums. - -# 0.3.3 - -- Fix: bug in TableReference -- Fix: if schema had newline or comment at the end, it crashed parser - -# 0.3.2 - -- Fix TableReference sql - -# 0.3.1 - -- Fix: files in **UTF-8 with BOM** encoding couldn't be parsed. - -# 0.3 - -- More tests and more bug fixes. -- Added index columns validation. -- Added table group items validation. -- References now contain link to Table and Column objects instead of just names. -- Indexes now contain link to Column objects instead of just names. - -# 0.2 - -- Better syntax errors. -- sql for each object now contains in `sql` property instead of string rerpresentation. Added proper string representations. -- Added syntax tests. -- Million bugs fixed after testing. - -# 0.1.1 - -- Comments are now parsed too if they are before [b] or on the same line [l] as the entity. Works for: tables[b], columns[lb], references [lb], indexes[lb], enum items [lb], enums [b], project [b] and table group [b] -- All class instances will now have an empty Note in `note` attribute instead of None. -- Add string representation for Note and EnumItem. -- Enum instance now acts like list of EnumItems. -- Add EnumType to use in column.type attribute. -- Column type is now replaced by EnumType instance if enum with such name is defined. -- Remove unnecessary ColumnType class. -- Fix: note definition, project definition, some other definitions - -# 0.1 - -- Initial release diff --git a/coverage.svg b/coverage.svg new file mode 100644 index 0000000..e5db27c --- /dev/null +++ b/coverage.svg @@ -0,0 +1,21 @@ + + + + + + + + + + + + + + + + coverage + coverage + 100% + 100% + + diff --git a/docs/classes.md b/docs/classes.md new file mode 100644 index 0000000..d4605bd --- /dev/null +++ b/docs/classes.md @@ -0,0 +1,329 @@ + +* [Database](#database) +* [Table](#table) +* [Column](#column) +* [Index](#index) +* [Reference](#reference) +* [Enum](#enum) +* [Note](#note) +* [StickyNote](#sticky_note) +* [Expression](#expression) +* [Project](#project) +* [TableGroup](#tablegroup) + +# Class Reference + +PyDBML classes represent database entities. They live in the `pydbml.classes` package. + +```python +>>> from pydbml.classes import Table, Column, Reference + +``` + +The `Database` class represents a PyDBML database. You can import it from the `pydbml` package. + +```python +>>> from pydbml import Database + +``` + +## Database + +`Database` is the main class, representing a PyDBML database. When PyDBML parses a .dbml file, it returns a `Database` object. This object holds all objects of the database and makes sure they are properly connected. You can access the `Database` object by calling the `database` property of each class (except child classes like `Column` or `Index`). + +When you are creating PyDBML schema from scratch, you have to add each created object to the database by calling `Database.add`. + +`Database` object may act as a list or a dictionary of tables: + +```python +>>> from pydbml import PyDBML +>>> db = PyDBML.parse_file('test_schema.dbml') +>>> table = db.tables[0] +>>> db['public.orders'] +
+>>> db[0] +
+ +``` + +### Attributes + +* **tables** (list of `Table`) — list of all `Table` objects, defined in this database. +* **table_dict** (dict of `Table`) — dictionary holding database `Table` objects. The key is full table name (with schema: `public.mytable`) or a table alias (`myalias`). +* **refs** (list of `Reference`) — list of all `Reference` objects, defined in this database. +* **enums** (list of `Enum`) — list of all `Enum` objects, defined in this database. +* **table_groups** (list of `TableGroup`) — list of all `TableGroup` objects, defined in this database. +* **project** (`Project`) — database `Project`. +* **sql** () — SQL definition for this database. +* **dbml** () — DBML definition for this table. + +### Methods + +* **add** (PyDBML object) — add a PyDBML object to the database. +* **add_table** (`Table`) — add a `Table` object to the database. +* **add_reference** (`Reference`) — add a `Reference` object to the database. +* **add_enum** (`Enum`) — add a `Enum` object to the database. +* **add_table_group** (`TableGroup`) — add a `TableGroup` object to the database. +* **add_project** (`Project`) — add a `Project` object to the database. +* **delete** (PyDBML object) — delete a PyDBML object from the database. +* **delete_table** (`Table`) — delete a `Table` object from the database. +* **delete_reference** (`Reference`) — delete a `Reference` object from the database. +* **delete_enum** (`Enum`) — delete a `Enum` object from the database. +* **delete_table_group** (`TableGroup`) — delete a `TableGroup` object from the database. +* **delete_project** (`Project`) — delete a `Project` object from the database. + +## Table + +`Table` class represents a database table. + +```python +>>> from pydbml import PyDBML +>>> parsed = PyDBML.parse_file('test_schema.dbml') +>>> table = parsed.tables[0] +>>> table +
+ +``` + +`Table` object may act as a list or a dictionary of columns: + +```python +>>> table[0] + +>>> table['status'] + + +``` + +### Attributes + +* **database** (`Database`) — link to the table's database object, if it was set. +* **name** (str) — table name. +* **schema** (str) — table schema name. +* **full_name** (str) — table name with schema prefix. +* **columns** (list of `Column`) — table columns. +* **indexes** (list of `Index`) — indexes, defined for the table. +* **alias** (str) — table alias, if defined. +* **note** (str) — note for table, if defined. +* **header_color** (str) — the header_color param, if defined. +* **comment** (str) — comment, if it was added just before table definition. +* **sql** (str) — SQL definition for this table. +* **dbml** (str) — DBML definition for this table. + +### Methods + +* **add_column** (c: `Column`) — add a column to the table, +* **delete_column** (c: `Column` or int) — delete a column from the table by Column object or column index. +* **add_index** (i: `Index`) — add an index to the table, +* **delete_index** (i: Index or int) — delete an index from the table by Index object or index number. +* **get_refs** — get list of references, defined for this table. +* **get_references_for_sql** — get list of references where this table is on the left side of FOREIGN KEY definition in SQL. + +## Column + +`Column` class represents a column of a database table. + +Table columns are stored in the `columns` attribute of a `Table` object. + +### Attributes + +* **database** (`Database`) — link to the database object of this column's table, if it was set. +* **name** (str) — column name, +* **table** (`Table`) — link to `Table` object, which holds this column. +* **type** (str or `Enum`) — column type. If type is a enum, this attribute will hold a link to corresponding `Enum` object. +* **unique** (bool) — indicates whether the column is unique. +* **not_null** (bool) — indicates whether the column is not null. +* **pk** (bool) — indicates whether the column is a primary key. +* **autoinc** (bool) — indicates whether this is an autoincrement column. +* **default** (str or bool or int or float or Expression) — column's default value. +* **note** (Note) — column's note if was defined. +* **comment** (str) — comment, if it was added just before column definition or right after it on the same line. +* **sql** (str) — SQL definition for this column. +* **dbml** (str) — DBML definition for this column. + +### Methods + +* **get_refs** — get list of references, defined for this column. + +## Index + +`Index` class represents an index of a database table. + +Indexes are stored in the `indexes` attribute of a `Table` object. + +### Attributes + +* **subjects** (list of `Column` or `Expression`) — list subjects which are indexed. Columns are represented by `Column` objects or `Expression` objects. +* **subject_names** (list of str) — list of index subject names. +* **table** (`Table`) — link to table, for which this index is defined. +* **name** (str) — index name, if defined. +* **unique** (bool) — indicates whether the index is unique. +* **type** (str) — index type, if defined. Accepted values: `brin`, `btree`, `gin`, `gist`, `hash`, `spgist`. +* **pk** (bool) — indicates whether this a primary key index. +* **note** (note) — index note, if defined. +* **comment** (str) — comment, if it was added just before index definition. +* **sql** (str) — SQL definition for this index. +* **dbml** (str) — DBML definition for this index. + +## Reference + +`Reference` class represents a database relation. + +```python +>>> from pydbml import PyDBML +>>> parsed = PyDBML.parse_file('test_schema.dbml') +>>> parsed.refs[0] + + +``` + +### Attributes + +* **database** (`Database`) — link to the reference's database object, if it was set. +* **type** (str) — reference type, in DBML syntax: + * `<` — one to many; + * `>` — many to one; + * `-` — one to one. +* **col1** (list of `Column`) — list of Column objects of the left side of the reference. Changed in **0.4.0**, previously was plain `Column`. +* **table1** (`Table` or `None`) — link to the left `Table` object of the reference or `None` of it was not set. +* **col2** (list of `Column`) — list of Column objects of the right side of the reference. Changed in **0.4.0**, previously was plain `Column`. +* **table2** (`Table` or `None`) — link to the right `Table` object of the reference or `None` of it was not set. +* **name** (str) — reference name, if defined. +* **on_update** (str) — reference's on update setting, if defined. +* **on_delete** (str) — reference's on delete setting, if defined. +* **comment** (str) — comment, if it was added before reference definition. +* **inline** (bool) — indicates whether this reference should be rendered inside SQL or DBML definition of the table. +* **sql** (str) — SQL definition for this reference. +* **dbml** (str) — DBML definition for this reference. + +## Enum + +`Enum` class represents a enum type in the database. + +```python +>>> from pydbml import PyDBML +>>> parsed = PyDBML.parse_file('test_schema.dbml') +>>> enum = parsed.enums[0] +>>> enum + + +``` + +Enum objects also act as a list of items: + +```python +>>> enum[0] + + +``` + +### Attributes + +database +name +schema +comment +items + +* **database** (`Database`) — link to the enum's database object, if it was set. +* **schema** (str) — enum schema name. +* **name** (str) — enum name, +* **items** (list of `EnumItem`) — list of items. +* **comment** (str) — comment, which was defined before enum definition. +* **sql** (str) — SQL definition for this enum. +* **dbml** (str) — DBML definition for this enum. + +### Methods + +* **add_item** (item: `EnumItem` or str) — add an item to this enum. + +### EnumItem + +`EnumItem` class represents an item of a enum type in the database. + +Enum items are stored in the `items` property of a `Enum` class. + +### Attributes + +* **name** (str) — enum item name, +* **note** (`Note`) — enum item note, if was defined. +* **comment** (str) — comment, which was defined before enum item definition or right after it on the same line. +* **sql** (str) — SQL definition for this enum item. +* **dbml** (str) — DBML definition for this enum item. + +## Note + +Note is a basic class, which may appear in some other classes' `note` attribute. Mainly used for documentation of a DBML database. + +### Attributes + +**text** (str) — note text. +* **sql** (str) — SQL definition for this note. +* **dbml** (str) — DBML definition for this note. + +## Note + +**new in PyDBML 1.0.10** + +Sticky notes are similar to regular notes, except that they are defined at the root of your DBML file and have a name. + +### Attributes + +**name** (str) — note name. +**text** (str) — note text. +* **dbml** (str) — DBML definition for this note. + +## Expression + +**new in PyDBML 1.0.0** + +`Expression` class represents an SQL expression. Expressions may appear in `Index` subjects or `Column` default values. + +### Attributes + +**text** (str) — expression text. +* **sql** (str) — SQL definition for this expression. +* **dbml** (str) — DBML definition for this expression. + +## Project + +`Project` class holds DBML project metadata. Project is not present in SQL. + +```python +>>> from pydbml import PyDBML +>>> parsed = PyDBML.parse_file('test_schema.dbml') +>>> parsed.project + + +``` + +### Attributes + +* **database** (`Database`) — link to the project's database object, if it was set. +* **name** (str) — project name, +* **items** (str) — dictionary with project metadata, +* **note** (`Note`) — note, if was defined, +* **comment** (str) — comment, if was added before project definition. +* **dbml** (str) — DBML definition for this project. + +## TableGroup + +`TableGroup` class represents a table group in the DBML database. TableGroups are not present in SQL. + +```python +>>> from pydbml import PyDBML +>>> parsed = PyDBML.parse_file('test_schema.dbml') +>>> parsed.table_groups +[, ] + +``` + +### Attributes + +* **database** (`Database`) — link to the tableg group's database object, if it was set. +* **name** (str) — table group name, +* **items** (str) — dictionary with tables in the group, +* **comment** (str) — comment, if was added before table group definition. +* **note** (Note) — table group's note if was defined. +* **color** (str) — the color param, if defined. +* **dbml** (str) — DBML definition for this table group. diff --git a/docs/creating_schema.md b/docs/creating_schema.md new file mode 100644 index 0000000..4b90dc9 --- /dev/null +++ b/docs/creating_schema.md @@ -0,0 +1,121 @@ +# Creating DBML schema + +You can use PyDBML not only for parsing DBML files, but also for creating schema from scratch in Python. + +## Database object + +You always start by creating a Database object. It will connect all other entities of the database for us. + +```python +>>> from pydbml import Database +>>> db = Database() + +``` + +Now let's create a table and add it to the database. + +```python +>>> from pydbml.classes import Table +>>> table1 = Table(name='products') +>>> db.add(table1) +
+ +``` + +To add columns to the table, you have to use the `add_column` method of the Table object. + +```python +>>> from pydbml.classes import Column +>>> col1 = Column(name='id', type='Integer', pk=True, autoinc=True) +>>> table1.add_column(col1) +>>> col2 = Column(name='product_name', type='Varchar', unique=True) +>>> table1.add_column(col2) +>>> col3 = Column(name='manufacturer_id', type='Integer') +>>> table1.add_column(col3) + +``` + +Index is also a part of a table, so you have to add it similarly, using `add_index` method: + +```python +>>> from pydbml.classes import Index +>>> index1 = Index([col2], unique=True) +>>> table1.add_index(index1) + +``` + +The table's third column, `manufacturer_id` looks like it should be a foreign key. Let's create another table, called `manufacturers`, so that we could create a relation. + +```python +>>> table2 = Table( +... 'manufacturers', +... columns=[ +... Column('id', type='Integer', pk=True, autoinc=True), +... Column('manufacturer_name', type='Varchar'), +... Column('manufacturer_country', type='Varchar') +... ] +... ) +>>> db.add(table2) +
+ +``` + +Now to the relation: + +```python +>>> from pydbml.classes import Reference +>>> ref = Reference('>', table1['manufacturer_id'], table2['id']) +>>> db.add(ref) +', ['manufacturer_id'], ['id']> + +``` + +You noticed that we are calling the `add` method on the Database after creating each object. While objects can somewhat function without being added to a database, DBML/SQL generation and some other useful methods won't work properly. + +Now let's generate DBML code for our schema. This is done by just calling the `dbml` property of the Database object: + +```python +>>> print(db.dbml) +Table "products" { + "id" Integer [pk, increment] + "product_name" Varchar [unique] + "manufacturer_id" Integer + + indexes { + product_name [unique] + } +} + +Table "manufacturers" { + "id" Integer [pk, increment] + "manufacturer_name" Varchar + "manufacturer_country" Varchar +} + +Ref { + "products"."manufacturer_id" > "manufacturers"."id" +} + +``` + +We can generate SQL for the schema similarly, by calling the `sql` property: + +```python +>>> print(db.sql) +CREATE TABLE "products" ( + "id" Integer PRIMARY KEY AUTOINCREMENT, + "product_name" Varchar UNIQUE, + "manufacturer_id" Integer +); + +CREATE UNIQUE INDEX ON "products" ("product_name"); + +CREATE TABLE "manufacturers" ( + "id" Integer PRIMARY KEY AUTOINCREMENT, + "manufacturer_name" Varchar, + "manufacturer_country" Varchar +); + +ALTER TABLE "products" ADD FOREIGN KEY ("manufacturer_id") REFERENCES "manufacturers" ("id"); + +``` diff --git a/docs/properties.md b/docs/properties.md new file mode 100644 index 0000000..f3f878a --- /dev/null +++ b/docs/properties.md @@ -0,0 +1,77 @@ +# Arbitrary Properties + +Since 1.1.0 PyDBML supports arbitrary properties in Table and Column definitions. Arbitrary properties is a dictionary of key-value pairs that can be added to any Table or Column manually, or parsed from a DBML file. This may be useful for extending the standard DBML syntax or keeping additional information in the schema. + +Arbitrary properties are turned off by default. To enable parsing properties in DBML files, set `allow_properties` argument to `True` in the parser call. To enable rendering properties in the output DBML of an existing database, set `allow_properties` database attribute to `True`. + +## Properties in DBML + +In a DBML file arbitrary properties are defined like this: + +```python +>>> dbml_str = ''' +... Table "products" { +... "id" integer +... "name" varchar [col_prop: 'some value'] +... table_prop: 'another value' +... }''' + +``` + +In this example we've added a property `col_prop` to the column `name` and a property `table_prop` to the table `products`. Note that property values must me single-quoted strings. Multiline strings (with `'''`) are supported. + +Now let's parse this DBML string: + +```python +>>> from pydbml import PyDBML +>>> mydb = PyDBML(dbml_str, allow_properties=True) +>>> mydb.tables[0].columns[1].properties +{'col_prop': 'some value'} +>>> mydb.tables[0].properties +{'table_prop': 'another value'} + +``` + +The `allow_properties=True` argument is crucial here. Without it, the parser will raise syntax errors. + +## Rendering Properties + +To render properties in the output DBML, set `allow_properties` attribute of the Database object to `True`. If you parsed the DBML with `allow_properties=True`, the result database will already have this attribute set to `True`. + +We will reuse the `mydb` database from the previous example: + +```python +>>> print(mydb.allow_properties) +True + +``` + +Let's set a new property on the table and render the DBML: + +```python +>>> mydb.tables[0].properties['new_prop'] = 'Multiline\nproperty\nvalue' +>>> print(mydb.dbml) +Table "products" { + "id" integer + "name" varchar [col_prop: 'some value'] + + table_prop: 'another value' + new_prop: ''' + Multiline + property + value''' +} + +``` + +As you see, properties are also rendered in the output DBML correctly. But if `allow_properties` is set to `False`, the properties will be ignored: + +```python +>>> mydb.allow_properties = False +>>> print(mydb.dbml) +Table "products" { + "id" integer + "name" varchar +} + +``` diff --git a/docs/upgrading.md b/docs/upgrading.md new file mode 100644 index 0000000..31ce7bc --- /dev/null +++ b/docs/upgrading.md @@ -0,0 +1,194 @@ +# Upgrading to PyDBML 1.0.0 + +When I created PyDBML back in April 2020, I just needed a DBML parser, and it was written as a parser. When people started using it, they wanted to also be able to edit DBML schema in Python and create it from scratch. While it worked to some extent, the project architecture was not completely ready for such usage. + +In May 2022 I've rewritten PyDBML from scratch and released version 1.0.0. Now you can not only parse DBML files, but also create them in Python and edit parsed schema. Sadly, it made the new version completely incompatible with the old one. This article will help you upgrade to PyDBML 1.0.0 and adapt your code to work with the new version. + +## Getting Tables From Parse Results by Name + +Previously the parser returned the `PyDBMLParseResults` object, now it returns a `Database` object. While they mostly can be operated similarly, now you can't get a table just by name. + +Since v2.4 DBML supports multiple schemas for tables and enums. PyDBML 1.0.0 also supports multiple schemas, but this means that there may be tables with the same name in different schemas. So now you can't get a table from the parse results just by name, you have to specify the schema too: + +```python +>>> from pydbml import PyDBML +>>> db = PyDBML.parse_file('test_schema.dbml') +>>> db['orders'] +Traceback (most recent call last): +... +KeyError: 'orders' +>>> db['public.orders'] +
+ +``` + +## New Table Object + +Previously the `Table` object had a `refs` attribute which holded a list of `TableReference` objects. `TableReference` represented a table relation and duplicated the `Reference` object of `PyDBMLParseResults` container. + +**In 1.0.0 the `TableReference` class is removed, and there's no `Table.refs` attribute.** + +Now each relation is represented by a single `Reference` object. You can still access `Table` references by calling the `get_refs` method. + +`Table.get_refs` will return a list of References for this table, but only if this table is on the left side of DBML relation. + +Here's an example DBML reference definition: + +```python +>>> source = ''' +... Table posts { +... id integer [primary key] +... user_id integer +... } +... +... Table users { +... id integer +... } +... +... Ref name_optional: posts.user_id > users.id +... ''' +>>> db = PyDBML(source) + +``` + +Here the many-to-one (`>`) relation is defined with the **posts** table on the left side, so calling `get_refs` on the **posts** table will return you this reference: + +```python +>>> db['public.posts'].get_refs() +[', ['user_id'], ['id']>] + +``` + +But calling `get_refs` on the **users** table won't give you the reference, because **users** is on the right side of the relation: + +```python +>>> db['public.users'].get_refs() +[] + +``` + +This depends on the side the table was referenced on, not on the type of the reference. So, if we modify the previous example to use one-to-many relation instead of many-to-one: + +```python +>>> source = ''' +... Table posts { +... id integer [primary key] +... user_id integer +... } +... +... Table users { +... id integer +... } +... +... Ref name_optional: users.id < posts.user_id +... ''' +>>> db = PyDBML(source) + +``` + +Now the **users** table is on the left, and we can only get the reference from the **users** table: + +```python +>>> db['public.users'].get_refs() +[] +>>> db['public.posts'].get_refs() +[] + +``` + +You can still get all the references for the database by accessing `Database.refs` property: + +```python +>>> db.refs +[] + +``` + +## New Reference Object + +Reference now can be explicitly inline. This is defined by the `Reference.inline` attribute. The `inline` attribute only affects how the reference will be rendered in table's SQL or DBML. + +Let's define an inline reference. + +```python +>>> from pydbml import Database +>>> from pydbml.classes import Table, Column, Reference +>>> db = Database() +>>> table1 = Table('products') +>>> db.add(table1) +
+>>> c1 = Column('name', 'varchar2') +>>> table1.add_column(c1) +>>> table2 = Table('names') +>>> db.add(table2) +
+>>> c2 = Column('name_val', 'varchar2') +>>> table2.add_column(c2) +>>> ref = Reference('>', c1, c2, inline=True) +>>> db.add(ref) +', ['name'], ['name_val']> +>>> print(table1.sql) +CREATE TABLE "products" ( + "name" varchar2, + FOREIGN KEY ("name") REFERENCES "names" ("name_val") +); + +``` + +If the reference is not inline, it won't appear in the Table SQL definition, otherwise it will be rendered separately as an `ALTER TABLE` clause: + +```python +>>> ref.inline = False +>>> print(table1.sql) +CREATE TABLE "products" ( + "name" varchar2 +); +>>> print(ref.sql) +ALTER TABLE "products" ADD FOREIGN KEY ("name") REFERENCES "names" ("name_val"); + +``` + +## `type_` -> `type` + +Previously you would initialize a `Column`, `Index` and `Reference` type with `type_` parameter. Now, this parameter is renamed to simply `type`. + +```python +>>> from pydbml.classes import Index, Column +>>> c = Column(name='name', type='varchar') +>>> c + +>>> t = Table('names') +>>> t.add_column(c) +>>> i = Index(subjects=[c], type='btree') +>>> t.add_index(i) +>>> i + +>>> t2 = Table('names_caps', columns=[Column('name_caps', 'varchar')]) +>>> ref = Reference(type='-', col1=t['name'], col2=t2['name_caps']) +>>> ref + + +``` + +## New Expression Class + +SQL expressions are allowed in column's `default` value definition and in index's subject definition. Previously, you defined expressions as parenthesized strings: `"(upper(name))"`. Now you have to use the `Expression` class. This will make sure the expression will be rendered properly in SQL and DBML. + +```python +>>> from pydbml.classes import Expression +>>> c = Column( +... name='upper_name', +... type='varchar', +... default=Expression('upper(name)') +... ) +>>> t = Table('names') +>>> t.add_column(c) +>>> db = Database() +>>> db.add(t) +
+>>> print(c.sql) +"upper_name" varchar DEFAULT (upper(name)) +>>> print(c.dbml) +"upper_name" varchar [default: `upper(name)`] + +``` diff --git a/pydbml/__init__.py b/pydbml/__init__.py index 879ea59..868de24 100644 --- a/pydbml/__init__.py +++ b/pydbml/__init__.py @@ -1,9 +1,15 @@ -from pydbml.parser import PyDBML, PyDBMLParseResults -import unittest -import doctest -from . import classes +import os +from . import _classes +from .parser import PyDBML +from .database import Database -def load_tests(loader, tests, ignore): - tests.addTests(doctest.DocTestSuite(classes)) - return tests \ No newline at end of file +load = PyDBML.parse_file +loads = PyDBML.parse + +def dump(db: Database, fp: str | os.PathLike): + with open(fp, 'w') as f: + f.write(db.dbml) + +def dumps(db: Database) -> str: + return db.dbml \ No newline at end of file diff --git a/pydbml/_classes/__init__.py b/pydbml/_classes/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pydbml/_classes/base.py b/pydbml/_classes/base.py new file mode 100644 index 0000000..d1f781f --- /dev/null +++ b/pydbml/_classes/base.py @@ -0,0 +1,70 @@ +from typing import Any +from typing import Tuple + +from pydbml.exceptions import AttributeMissingError + + +class SQLObject: + ''' + Base class for all SQL objects. + ''' + required_attributes: Tuple[str, ...] = () + dont_compare_fields: Tuple[str, ...] = () + + def check_attributes_for_sql(self): + ''' + Check if all attributes, required for rendering SQL are set in the + instance. If some attribute is missing, raise AttributeMissingError + ''' + for attr in self.required_attributes: + if getattr(self, attr) is None: + raise AttributeMissingError( + f'Cannot render SQL. Missing required attribute "{attr}".' + ) + @property + def sql(self) -> str: + if hasattr(self, 'database') and self.database is not None: + renderer = self.database.sql_renderer + else: + from pydbml.renderer.sql.default import DefaultSQLRenderer + renderer = DefaultSQLRenderer + + return renderer.render(self) + + def __setattr__(self, name: str, value: Any): + """ + Required for type testing with MyPy. + """ + super().__setattr__(name, value) + + def __eq__(self, other: object) -> bool: + """ + Two instances of the same SQLObject subclass are equal if all their + attributes are equal. + """ + + if not isinstance(other, self.__class__): + return False + # not comparing those because they are circular references + + self_dict = dict(self.__dict__) + other_dict = dict(other.__dict__) + + for field in self.dont_compare_fields: + self_dict.pop(field, None) + other_dict.pop(field, None) + + return self_dict == other_dict + + +class DBMLObject: + '''Base class for all DBML objects.''' + @property + def dbml(self) -> str: + if hasattr(self, 'database') and self.database is not None: + renderer = self.database.dbml_renderer + else: + from pydbml.renderer.dbml.default import DefaultDBMLRenderer + renderer = DefaultDBMLRenderer + + return renderer.render(self) diff --git a/pydbml/_classes/column.py b/pydbml/_classes/column.py new file mode 100644 index 0000000..4868cd1 --- /dev/null +++ b/pydbml/_classes/column.py @@ -0,0 +1,94 @@ +from typing import List, Dict +from typing import Optional +from typing import TYPE_CHECKING +from typing import Union + +from pydbml.exceptions import TableNotFoundError +from .base import SQLObject, DBMLObject +from .enum import Enum +from .expression import Expression +from .note import Note + +if TYPE_CHECKING: # pragma: no cover + from .table import Table + from .reference import Reference + + +class Column(SQLObject, DBMLObject): + '''Class representing table column.''' + + required_attributes = ('name', 'type') + dont_compare_fields = ('table',) + + def __init__(self, + name: str, + type: Union[str, Enum], + unique: bool = False, + not_null: bool = False, + pk: bool = False, + autoinc: bool = False, + default: Optional[Union[str, int, bool, float, Expression]] = None, + note: Optional[Union[Note, str]] = None, + comment: Optional[str] = None, + properties: Union[Dict[str, str], None] = None + ): + self.name = name + self.type = type + self.unique = unique + self.not_null = not_null + self.pk = pk + self.autoinc = autoinc + self.comment = comment + self.note = Note(note) + self.properties = properties if properties else {} + + self.default = default + self.table: Optional['Table'] = None + + def __eq__(self, other: object) -> bool: + if other is self: + return True + if not isinstance(other, self.__class__): + return False + self_table = self.table.full_name if self.table else None + other_table = other.table.full_name if other.table else None + if self_table != other_table: + return False + return super().__eq__(other) + + @property + def note(self): + return self._note + + @note.setter + def note(self, val: Note) -> None: + self._note = val + val.parent = self + + def get_refs(self) -> List['Reference']: + ''' + get all references related to this column (where this col is col1 in) + ''' + if not self.table: + raise TableNotFoundError('Table for the column is not set') + return [ref for ref in self.table.get_refs() if self in ref.col1] + + @property + def database(self): + return self.table.database if self.table else None + + def __repr__(self): + ''' + >>> Column('name', 'VARCHAR2') + + ''' + type_name = self.type if isinstance(self.type, str) else self.type.name + return f'' + + def __str__(self): + ''' + >>> print(Column('name', 'VARCHAR2')) + name[VARCHAR2] + ''' + + return f'{self.name}[{self.type}]' diff --git a/pydbml/_classes/enum.py b/pydbml/_classes/enum.py new file mode 100644 index 0000000..0b2aee6 --- /dev/null +++ b/pydbml/_classes/enum.py @@ -0,0 +1,89 @@ +from typing import Iterable +from typing import List +from typing import Optional +from typing import Union + +from .base import SQLObject, DBMLObject +from .note import Note + + +class EnumItem(SQLObject, DBMLObject): + '''Single enum item''' + + required_attributes = ('name',) + + def __init__(self, + name: str, + note: Optional[Union[Note, str]] = None, + comment: Optional[str] = None): + self.name = name + self.note = Note(note) + self.comment = comment + + @property + def note(self): + return self._note + + @note.setter + def note(self, val: Note) -> None: + self._note = val + val.parent = self + + def __repr__(self): + '''''' + return f'' + + def __str__(self): + '''en-US''' + return self.name + + +class Enum(SQLObject, DBMLObject): + required_attributes = ('name', 'schema', 'items') + + def __init__(self, + name: str, + items: Iterable[Union['EnumItem', str]], + schema: str = 'public', + comment: Optional[str] = None): + self.database = None + self.name = name + self.schema = schema + self.comment = comment + self.items: List[EnumItem] = [] + for item in items: + self.add_item(item) + + def add_item(self, item: Union['EnumItem', str]) -> None: + if isinstance(item, EnumItem): + self.items.append(item) + elif isinstance(item, str): + self.items.append(EnumItem(item)) + + def __getitem__(self, key: int) -> EnumItem: + return self.items[key] + + def __iter__(self): + return iter(self.items) + + def __repr__(self): + ''' + >>> en = EnumItem('en-US') + >>> ru = EnumItem('ru-RU') + >>> Enum('languages', [en, ru]) + + ''' + + item_names = [i.name for i in self.items] + classname = self.__class__.__name__ + return f'<{classname} {self.name!r}, {item_names!r}>' + + def __str__(self): + ''' + >>> en = EnumItem('en-US') + >>> ru = EnumItem('ru-RU') + >>> print(Enum('languages', [en, ru])) + languages + ''' + + return self.name diff --git a/pydbml/_classes/expression.py b/pydbml/_classes/expression.py new file mode 100644 index 0000000..8b34445 --- /dev/null +++ b/pydbml/_classes/expression.py @@ -0,0 +1,22 @@ +from .base import SQLObject, DBMLObject + + +class Expression(SQLObject, DBMLObject): + def __init__(self, text: str): + self.text = text + + def __str__(self) -> str: + ''' + >>> print(Expression('sum(amount)')) + sum(amount) + ''' + + return self.text + + def __repr__(self) -> str: + ''' + >>> Expression('sum(amount)') + Expression('sum(amount)') + ''' + + return f'Expression({repr(self.text)})' diff --git a/pydbml/_classes/index.py b/pydbml/_classes/index.py new file mode 100644 index 0000000..0d6ee24 --- /dev/null +++ b/pydbml/_classes/index.py @@ -0,0 +1,80 @@ +from typing import List +from typing import Literal +from typing import Optional +from typing import TYPE_CHECKING +from typing import Union + +from .base import SQLObject, DBMLObject +from .column import Column +from .expression import Expression +from .note import Note + +if TYPE_CHECKING: # pragma: no cover + from .table import Table + + +class Index(SQLObject, DBMLObject): + '''Class representing index.''' + required_attributes = ('subjects', 'table') + dont_compare_fields = ('table',) + + def __init__(self, + subjects: List[Union[str, Column, Expression]], + name: Optional[str] = None, + unique: bool = False, + type: Optional[ + Literal[ + # https://www.postgresql.org/docs/current/indexes-types.html + "brin", + "btree", + "gin", + "gist", + "hash", + "spgist", + ] + ] = None, + pk: bool = False, + note: Optional[Union[Note, str]] = None, + comment: Optional[str] = None): + self.subjects = subjects + self.table: Optional[Table] = None + + self.name = name if name else None + self.unique = unique + self.type = type + self.pk = pk + self.note = Note(note) + self.comment = comment + + @property + def note(self): + return self._note + + @note.setter + def note(self, val: Note) -> None: + self._note = val + val.parent = self + + @property + def subject_names(self): + ''' + Returns updated list of subject names. + ''' + return [s.name if isinstance(s, Column) else str(s) for s in self.subjects] + + def __repr__(self): + ''' + + ''' + + table_name = self.table.name if self.table else None + return f"" + + def __str__(self): + ''' + Index(test[col, (c*2)]) + ''' + + table_name = self.table.name if self.table else '' + subjects = ', '.join(self.subject_names) + return f"Index({table_name}[{subjects}])" diff --git a/pydbml/_classes/note.py b/pydbml/_classes/note.py new file mode 100644 index 0000000..627eb07 --- /dev/null +++ b/pydbml/_classes/note.py @@ -0,0 +1,23 @@ +from typing import Any + +from .base import SQLObject, DBMLObject + + +class Note(SQLObject, DBMLObject): + dont_compare_fields = ('parent',) + + def __init__(self, text: Any) -> None: + self.text: str + self.text = str(text) if text is not None else '' + self.parent: Any = None + + def __str__(self): + '''Note text''' + return self.text + + def __bool__(self): + return bool(self.text) + + def __repr__(self): + '''Note('Note text')''' + return f'Note({repr(self.text)})' diff --git a/pydbml/_classes/project.py b/pydbml/_classes/project.py new file mode 100644 index 0000000..4e303c5 --- /dev/null +++ b/pydbml/_classes/project.py @@ -0,0 +1,34 @@ +from typing import Dict +from typing import Optional +from typing import Union + +from pydbml._classes.base import DBMLObject +from pydbml._classes.note import Note + + +class Project(DBMLObject): + dont_compare_fields = ('database',) + + def __init__(self, + name: str, + items: Optional[Dict[str, str]] = None, + note: Optional[Union[Note, str]] = None, + comment: Optional[str] = None): + self.database = None + self.name = name + self.items = items or {} + self.note = Note(note) + self.comment = comment + + def __repr__(self): + """""" + return f'' + + @property + def note(self): + return self._note + + @note.setter + def note(self, val: Note) -> None: + self._note = val + val.parent = self diff --git a/pydbml/_classes/reference.py b/pydbml/_classes/reference.py new file mode 100644 index 0000000..9da1e03 --- /dev/null +++ b/pydbml/_classes/reference.py @@ -0,0 +1,120 @@ +from itertools import chain +from typing import Collection +from typing import Literal +from typing import Optional +from typing import Union + +from pydbml.constants import MANY_TO_MANY +from pydbml.exceptions import DBMLError +from pydbml.exceptions import TableNotFoundError +from .base import SQLObject, DBMLObject +from .column import Column +from .table import Table + + +class Reference(SQLObject, DBMLObject): + ''' + Class, representing a foreign key constraint. + It is a separate object, which is not connected to Table or Column objects + and its `sql` property contains the ALTER TABLE clause. + ''' + required_attributes = ('type', 'col1', 'col2') + dont_compare_fields = ('database', '_inline') + + def __init__(self, + type: Literal['>', '<', '-', '<>'], + col1: Union[Column, Collection[Column]], + col2: Union[Column, Collection[Column]], + name: Optional[str] = None, + comment: Optional[str] = None, + on_update: Optional[str] = None, + on_delete: Optional[str] = None, + inline: bool = False): + self.database = None + self.type = type + self.col1 = [col1] if isinstance(col1, Column) else list(col1) + self.col2 = [col2] if isinstance(col2, Column) else list(col2) + self.name = name if name else None + self.comment = comment + self.on_update = on_update + self.on_delete = on_delete + self._inline = inline + + @property + def inline(self) -> bool: + return self._inline and not self.type == MANY_TO_MANY + + @inline.setter + def inline(self, val) -> None: + self._inline = val + + @property + def join_table(self) -> Optional[Table]: + if self.type != MANY_TO_MANY: + return None + + if self.table1 is None: + raise TableNotFoundError(f"Cannot generate join table for {self}: table 1 is unknown") + if self.table2 is None: + raise TableNotFoundError(f"Cannot generate join table for {self}: table 2 is unknown") + + return Table( + name=f'{self.table1.name}_{self.table2.name}', + schema=self.table1.schema, + columns=( + Column(name=f'{c.table.name}_{c.name}', type=c.type, not_null=True, pk=True) # type: ignore + for c in chain(self.col1, self.col2) + ), + abstract=True + ) + + @property + def table1(self) -> Optional[Table]: + self._validate() + return self.col1[0].table if self.col1 else None + + @property + def table2(self) -> Optional[Table]: + self._validate() + return self.col2[0].table if self.col2 else None + + def __repr__(self): + ''' + >>> c1 = Column('c1', 'int') + >>> c2 = Column('c2', 'int') + >>> Reference('>', col1=c1, col2=c2) + ', ['c1'], ['c2']> + >>> c12 = Column('c12', 'int') + >>> c22 = Column('c22', 'int') + >>> Reference('<', col1=[c1, c12], col2=(c2, c22)) + + ''' + + col1 = ', '.join(f'{c.name!r}' for c in self.col1) + col2 = ', '.join(f'{c.name!r}' for c in self.col2) + return f"" + + def __str__(self): + ''' + >>> c1 = Column('c1', 'int') + >>> c2 = Column('c2', 'int') + >>> print(Reference('>', col1=c1, col2=c2)) + Reference([c1] > [c2] + >>> c12 = Column('c12', 'int') + >>> c22 = Column('c22', 'int') + >>> print(Reference('<', col1=[c1, c12], col2=(c2, c22))) + Reference([c1, c12] < [c2, c22] + ''' + + col1 = ', '.join(f'{c.name}' for c in self.col1) + col2 = ', '.join(f'{c.name}' for c in self.col2) + return f"Reference([{col1}] {self.type} [{col2}]" + + def _validate(self): + table1 = self.col1[0].table + if any(c.table != table1 for c in self.col1): + raise DBMLError('Columns in col1 are from different tables') + + table2 = self.col2[0].table + if any(c.table != table2 for c in self.col2): + raise DBMLError('Columns in col2 are from different tables') diff --git a/pydbml/_classes/sticky_note.py b/pydbml/_classes/sticky_note.py new file mode 100644 index 0000000..e01e1b0 --- /dev/null +++ b/pydbml/_classes/sticky_note.py @@ -0,0 +1,24 @@ +from typing import Any + +from pydbml._classes.base import DBMLObject + + +class StickyNote(DBMLObject): + dont_compare_fields = ('database',) + + def __init__(self, name: str, text: Any) -> None: + self.name = name + self.text = str(text) if text is not None else '' + + self.database = None + + def __str__(self): + '''StickyNote('mynote', 'Note text')''' + return self.__class__.__name__ + f'({repr(self.name)}, {repr(self.text)})' + + def __bool__(self): + return bool(self.text) + + def __repr__(self): + '''''' + return f'<{self.__class__.__name__} {self.name!r}, {self.text!r}>' diff --git a/pydbml/_classes/table.py b/pydbml/_classes/table.py new file mode 100644 index 0000000..b8f340d --- /dev/null +++ b/pydbml/_classes/table.py @@ -0,0 +1,158 @@ +from typing import Iterable, Dict +from typing import List +from typing import Optional +from typing import TYPE_CHECKING +from typing import Union + +from pydbml.exceptions import ColumnNotFoundError +from pydbml.exceptions import IndexNotFoundError +from pydbml.exceptions import UnknownDatabaseError +from .base import SQLObject, DBMLObject +from .column import Column +from .index import Index +from .note import Note + +if TYPE_CHECKING: # pragma: no cover + from pydbml.database import Database + from .reference import Reference + + +class Table(SQLObject, DBMLObject): + '''Class representing table.''' + + required_attributes = ('name', 'schema') + dont_compare_fields = ('database',) + + def __init__(self, + name: str, + schema: str = 'public', + alias: Optional[str] = None, + columns: Optional[Iterable[Column]] = None, + indexes: Optional[Iterable[Index]] = None, + note: Optional[Union[Note, str]] = None, + header_color: Optional[str] = None, + comment: Optional[str] = None, + abstract: bool = False, + properties: Union[Dict[str, str], None] = None + ): + self.database: Optional[Database] = None + self.name = name + self.schema = schema + self.columns: List[Column] = [] + for column in columns or []: + self.add_column(column) + self.indexes: List[Index] = [] + for index in indexes or []: + self.add_index(index) + self.alias = alias if alias else None + self.note = Note(note) + self.header_color = header_color + self.comment = comment + self.abstract = abstract + self.properties = properties if properties else {} + + @property + def note(self): + return self._note + + @note.setter + def note(self, val: Note) -> None: + self._note = val + val.parent = self + + @property + def full_name(self) -> str: + return f'{self.schema}.{self.name}' + + def _has_composite_pk(self) -> bool: + return sum(c.pk for c in self.columns) > 1 + + def add_column(self, c: Column) -> None: + ''' + Adds column to self.columns attribute and sets in this column the + `table` attribute. + ''' + if not isinstance(c, Column): + raise TypeError('Columns must be of type Column') + c.table = self + self.columns.append(c) + + def delete_column(self, c: Union[Column, int]) -> Column: + if isinstance(c, Column): + if c in self.columns: + c.table = None + return self.columns.pop(self.columns.index(c)) + else: + raise ColumnNotFoundError(f'Column {c} if missing in the table') + elif isinstance(c, int): + self.columns[c].table = None + return self.columns.pop(c) + + def add_index(self, i: Index) -> None: + ''' + Adds index to self.indexes attribute and sets in this index the + `table` attribute. + ''' + if not isinstance(i, Index): + raise TypeError('Indexes must be of type Index') + for subject in i.subjects: + if isinstance(subject, Column) and subject.table is not self: + raise ColumnNotFoundError(f'Column {subject} not in the table') + i.table = self + self.indexes.append(i) + + def delete_index(self, i: Union[Index, int]) -> Index: + if isinstance(i, Index): + if i in self.indexes: + i.table = None + return self.indexes.pop(self.indexes.index(i)) + else: + raise IndexNotFoundError(f'Index {i} if missing in the table') + elif isinstance(i, int): + self.indexes[i].table = None + return self.indexes.pop(i) + + def get_refs(self) -> List['Reference']: + if not self.database: + raise UnknownDatabaseError('Database for the table is not set') + return [ref for ref in self.database.refs if ref.table1 == self] + + def __getitem__(self, k: Union[int, str]) -> Column: + if isinstance(k, int): + return self.columns[k] + elif isinstance(k, str): + for c in self.columns: + if c.name == k: + return c + raise ColumnNotFoundError(f'Column {k} not present in table {self.name}') + else: + raise TypeError('indeces must be str or int') + + def get(self, k, default: Optional[Column] = None) -> Optional[Column]: + try: + return self.__getitem__(k) + except (IndexError, ColumnNotFoundError): + return default + + def __iter__(self): + return iter(self.columns) + + def __repr__(self): + ''' + >>> table = Table('customers') + >>> table +
+ ''' + + return f'
' + + def __str__(self): + ''' + >>> table = Table('customers') + >>> table.add_column(Column('id', 'INTEGER')) + >>> table.add_column(Column('name', 'VARCHAR2')) + >>> print(table) + public.customers(id, name) + ''' + + return f'{self.schema}.{self.name}({", ".join(c.name for c in self.columns)})' diff --git a/pydbml/_classes/table_group.py b/pydbml/_classes/table_group.py new file mode 100644 index 0000000..cd9abc6 --- /dev/null +++ b/pydbml/_classes/table_group.py @@ -0,0 +1,49 @@ +from typing import List +from typing import Optional + +from pydbml._classes.base import DBMLObject +from pydbml._classes.note import Note +from pydbml._classes.table import Table + + +class TableGroup(DBMLObject): + ''' + TableGroup `items` parameter initially holds just the names of the tables, + but after parsing the whole document, PyDBMLParseResults class replaces + them with references to actual tables. + ''' + dont_compare_fields = ('database',) + + def __init__(self, + name: str, + items: List[Table], + comment: Optional[str] = None, + note: Optional[Note] = None, + color: Optional[str] = None): + self.database = None + self.name = name + self.items = items + self.comment = comment + self.note = note + self.color = color + + def __repr__(self): + """ + >>> tg = TableGroup('mygroup', ['t1', 't2']) + >>> tg + + >>> t1 = Table('t1') + >>> t2 = Table('t2') + >>> tg.items = [t1, t2] + >>> tg + + """ + + items = [i if isinstance(i, str) else i.name for i in self.items] + return f'' + + def __getitem__(self, key: int) -> Table: + return self.items[key] + + def __iter__(self): + return iter(self.items) diff --git a/pydbml/classes.py b/pydbml/classes.py deleted file mode 100644 index 2b1294a..0000000 --- a/pydbml/classes.py +++ /dev/null @@ -1,799 +0,0 @@ -from __future__ import annotations - -from typing import Any -from typing import Dict -from typing import List -from typing import Collection -from typing import Optional -from typing import Tuple -from typing import Union - -from .exceptions import AttributeMissingError -from .exceptions import ColumnNotFoundError -from .exceptions import DuplicateReferenceError - - -class SQLOjbect: - ''' - Base class for all SQL objects. - ''' - required_attributes: Tuple[str, ...] = () - - def check_attributes_for_sql(self): - ''' - Check if all attributes, required for rendering SQL are set in the - instance. If some attribute is missing, raise AttributeMissingError - ''' - for attr in self.required_attributes: - if getattr(self, attr) is None: - raise AttributeMissingError( - f'Cannot render SQL. Missing required attribute "{attr}".' - ) - - def __setattr__(self, name: str, value: Any): - """ - Required for type testing with MyPy. - """ - super().__setattr__(name, value) - - def __eq__(self, other: object) -> bool: - """ - Two instances of the same SQLObject subclass are equal if all their - attributes are equal. - """ - - if isinstance(other, self.__class__): - return self.__dict__ == other.__dict__ - return False - - -class ReferenceBlueprint: - ''' - Intermediate class for references during parsing. Table and columns are just - strings at this point, as we can't check their validity until all schema - is parsed. - - Note: `table2` and `col2` params are technically required (left optional for aesthetics). - ''' - ONE_TO_MANY = '<' - MANY_TO_ONE = '>' - ONE_TO_ONE = '-' - - def __init__(self, - type_: str, - name: Optional[str] = None, - table1: Optional[str] = None, - col1: Optional[Union[str, Collection[str]]] = None, - table2: Optional[str] = None, - col2: Optional[Union[str, Collection[str]]] = None, - comment: Optional[str] = None, - on_update: Optional[str] = None, - on_delete: Optional[str] = None): - self.type = type_ - self.name = name if name else None - self.table1 = table1 if table1 else None - self.col1 = col1 if col1 else None - self.table2 = table2 if table2 else None - self.col2 = col2 if col2 else None - self.comment = comment - self.on_update = on_update - self.on_delete = on_delete - - def __repr__(self): - ''' - >>> ReferenceBlueprint('>', table1='t1', col1='c1', table2='t2', col2='c2') - ', 't1'.'c1', 't2'.'c2'> - >>> ReferenceBlueprint('<', table2='t2', col2='c2') - - >>> ReferenceBlueprint('>', table1='t1', col1=('c11', 'c12'), table2='t2', col2=['c21', 'c22']) - ', 't1'.('c11', 'c12'), 't2'.['c21', 'c22']> - ''' - - components = [f"' - - def __str__(self): - ''' - >>> r1 = ReferenceBlueprint('>', table1='t1', col1='c1', table2='t2', col2='c2') - >>> r2 = ReferenceBlueprint('<', table2='t2', col2='c2') - >>> r3 = ReferenceBlueprint('>', table1='t1', col1=('c11', 'c12'), table2='t2', col2=('c21', 'c22')) - >>> print(r1, r2) - ReferenceBlueprint(t1.c1 > t2.c2) ReferenceBlueprint(< t2.c2) - >>> print(r3) - ReferenceBlueprint(t1[c11, c12] > t2[c21, c22]) - ''' - - components = [f"ReferenceBlueprint("] - if self.table1: - components.append(self.table1) - if self.col1: - if isinstance(self.col1, str): - components.append(f'.{self.col1} ') - else: # list or tuple - components.append(f'[{", ".join(self.col1)}] ') - components.append(f'{self.type} ') - components.append(self.table2) - if isinstance(self.col2, str): - components.append(f'.{self.col2}') - else: # list or tuple - components.append(f'[{", ".join(self.col2)}]') - return ''.join(components) + ')' - - -class Reference(SQLOjbect): - ''' - Class, representing a foreign key constraint. - It is a separate object, which is not connected to Table or Column objects - and its `sql` property contains the ALTER TABLE clause. - ''' - required_attributes = ('type', 'table1', 'col1', 'table2', 'col2') - - ONE_TO_MANY = '<' - MANY_TO_ONE = '>' - ONE_TO_ONE = '-' - - def __init__(self, - type_: str, - table1: Table, - col1: Union[Column, Collection[Column]], - table2: Table, - col2: Union[Column, Collection[Column]], - name: Optional[str] = None, - comment: Optional[str] = None, - on_update: Optional[str] = None, - on_delete: Optional[str] = None): - self.type = type_ - self.table1 = table1 - self.col1 = [col1] if isinstance(col1, Column) else list(col1) - self.table2 = table2 - self.col2 = [col2] if isinstance(col2, Column) else list(col2) - self.name = name if name else None - self.comment = comment - self.on_update = on_update - self.on_delete = on_delete - - def __repr__(self): - ''' - >>> c1 = Column('c1', 'int') - >>> c2 = Column('c2', 'int') - >>> t1 = Table('t1') - >>> t2 = Table('t2') - >>> Reference('>', table1=t1, col1=c1, table2=t2, col2=c2) - ', 't1'.['c1'], 't2'.['c2']> - >>> c12 = Column('c12', 'int') - >>> c22 = Column('c22', 'int') - >>> Reference('<', table1=t1, col1=[c1, c12], table2=t2, col2=(c2, c22)) - - ''' - - components = [f"' - - def __str__(self): - ''' - >>> c1 = Column('c1', 'int') - >>> c2 = Column('c2', 'int') - >>> t1 = Table('t1') - >>> t2 = Table('t2') - >>> print(Reference('>', table1=t1, col1=c1, table2=t2, col2=c2)) - Reference(t1[c1] > t2[c2]) - >>> c12 = Column('c12', 'int') - >>> c22 = Column('c22', 'int') - >>> print(Reference('<', table1=t1, col1=[c1, c12], table2=t2, col2=(c2, c22))) - Reference(t1[c1, c12] < t2[c2, c22]) - ''' - - components = [f"Reference("] - components.append(self.table1.name) - components.append(f'[{", ".join(c.name for c in self.col1)}]') - components.append(f' {self.type} ') - components.append(self.table2.name) - components.append(f'[{", ".join(c.name for c in self.col2)}]') - return ''.join(components) + ')' - - @property - def sql(self): - ''' - Returns SQL of the reference: - - ALTER TABLE "orders" ADD FOREIGN KEY ("customer_id") REFERENCES "customers ("id"); - - ''' - self.check_attributes_for_sql() - c = f'CONSTRAINT "{self.name}" ' if self.name else '' - - if self.type in (self.MANY_TO_ONE, self.ONE_TO_ONE): - t1 = self.table1 - c1 = ', '.join(self.col1) - t2 = self.table2 - c2 = ', '.join(self.col2) - else: - t1 = self.table2 - c1 = ', '.join(self.col2) - t2 = self.table1 - c2 = ', '.join(self.col1) - - result = ( - f'ALTER TABLE "{t1.name}" ADD {c}FOREIGN KEY ("{c1.name}") ' - f'REFERENCES "{t2.name} ("{c2.name}")' - ) - if self.on_update: - result += f' ON UPDATE {self.on_update.upper()}' - if self.on_delete: - result += f' ON DELETE {self.on_delete.upper()}' - return result + ';' - - -class TableReference(SQLOjbect): - ''' - Class, representing a foreign key constraint. - This object should be assigned to the `refs` attribute of a Table object. - Its `sql` property contains the inline definition of the FOREIGN KEY clause. - ''' - required_attributes = ('col', 'ref_table', 'ref_col') - - def __init__(self, - col: Union[Column, List[Column]], - ref_table: Table, - ref_col: Union[Column, List[Column]], - name: Optional[str] = None, - on_delete: Optional[str] = None, - on_update: Optional[str] = None): - self.col = [col] if isinstance(col, Column) else list(col) - self.ref_table = ref_table - self.ref_col = [ref_col] if isinstance(ref_col, Column) else list(ref_col) - self.name = name - self.on_update = on_update - self.on_delete = on_delete - - def __repr__(self): - ''' - >>> c1 = Column('c1', 'int') - >>> c2 = Column('c2', 'int') - >>> t2 = Table('t2') - >>> TableReference(col=c1, ref_table=t2, ref_col=c2) - - >>> c12 = Column('c12', 'int') - >>> c22 = Column('c22', 'int') - >>> TableReference(col=[c1, c12], ref_table=t2, ref_col=(c2, c22)) - - ''' - - col_names = [c.name for c in self.col] - ref_col_names = [c.name for c in self.ref_col] - return f"" - - def __str__(self): - ''' - >>> c1 = Column('c1', 'int') - >>> c2 = Column('c2', 'int') - >>> t2 = Table('t2') - >>> print(TableReference(col=c1, ref_table=t2, ref_col=c2)) - TableReference([c1] > t2[c2]) - >>> c12 = Column('c12', 'int') - >>> c22 = Column('c22', 'int') - >>> print(TableReference(col=[c1, c12], ref_table=t2, ref_col=(c2, c22))) - TableReference([c1, c12] > t2[c2, c22]) - ''' - - components = [f"TableReference("] - components.append(f'[{", ".join(c.name for c in self.col)}]') - components.append(' > ') - components.append(self.ref_table.name) - components.append(f'[{", ".join(c.name for c in self.ref_col)}]') - return ''.join(components) + ')' - - @property - def sql(self): - ''' - Returns inline SQL of the reference, which should be a part of table definition: - - FOREIGN KEY ("order_id") REFERENCES "orders ("id") - - ''' - self.check_attributes_for_sql() - c = f'CONSTRAINT "{self.name}" ' if self.name else '' - cols = '", "'.join(c.name for c in self.col) - ref_cols = '", "'.join(c.name for c in self.ref_col) - result = ( - f'{c}FOREIGN KEY ("{cols}") ' - f'REFERENCES "{self.ref_table.name} ("{ref_cols}")' - ) - if self.on_update: - result += f' ON UPDATE {self.on_update.upper()}' - if self.on_delete: - result += f' ON DELETE {self.on_delete.upper()}' - return result - - -class Note: - def __init__(self, text: str): - self.text = text - - def __str__(self): - ''' - >>> print(Note('Note text')) - Note text - ''' - - return self.text - - def __bool__(self): - return bool(self.text) - - def __repr__(self): - ''' - >>> Note('Note text') - Note('Note text') - ''' - - return f'Note({repr(self.text)})' - - @property - def sql(self): - if self.text: - return '\n'.join(f'-- {line}' for line in self.text.split('\n')) - else: - return '' - - -class Column(SQLOjbect): - '''Class representing table column.''' - - required_attributes = ('name', 'type') - - def __init__(self, - name: str, - type_: str, - unique: bool = False, - not_null: bool = False, - pk: bool = False, - autoinc: bool = False, - default: Optional[Union[str, int, bool, float]] = None, - note: Optional[Note] = None, - ref_blueprints: Optional[List[ReferenceBlueprint]] = None, - comment: Optional[str] = None): - self.name = name - self.type = type_ - self.unique = unique - self.not_null = not_null - self.pk = pk - self.autoinc = autoinc - self.comment = comment - - self.default = default - - self.note = note or Note('') - self.ref_blueprints = ref_blueprints or [] - for ref in self.ref_blueprints: - ref.col1 = self.name - - self._table: Optional[Table] = None - - @property - def table(self) -> Optional[Table]: - return self._table - - @table.setter - def table(self, v: Table): - self._table = v - for ref in self.ref_blueprints: - ref.table1 = v.name - - @property - def sql(self): - ''' - Returns inline SQL of the column, which should be a part of table definition: - - "id" integer PRIMARY KEY AUTOINCREMENT - ''' - - self.check_attributes_for_sql() - components = [f'"{self.name}"', str(self.type)] - if self.pk: - components.append('PRIMARY KEY') - if self.autoinc: - components.append('AUTOINCREMENT') - if self.unique: - components.append('UNIQUE') - if self.not_null: - components.append('NOT NULL') - if self.default is not None: - components.append('DEFAULT ' + str(self.default)) - if self.note: - components.append(self.note.sql) - return ' '.join(components) - - def __repr__(self): - ''' - >>> Column('name', 'VARCHAR2') - - ''' - type_name = self.type if isinstance(self.type, str) else self.type.name - return f'' - - def __str__(self): - ''' - >>> print(Column('name', 'VARCHAR2')) - name[VARCHAR2] - ''' - - return f'{self.name}[{self.type}]' - - -class Index(SQLOjbect): - '''Class representing index.''' - required_attributes = ('subjects', 'table') - - def __init__(self, - subject_names: List[str], - name: Optional[str] = None, - table: Optional[Table] = None, - unique: bool = False, - type_: Optional[str] = None, - pk: bool = False, - note: Optional[Note] = None, - comment: Optional[str] = None): - self.subject_names = subject_names - self.subjects: List[Union[Column, str]] = [] - - self.name = name if name else None - self.table = table - self.unique = unique - self.type = type_ - self.pk = pk - self.note = note or Note('') - self.comment = comment - - def __repr__(self): - ''' - >>> Index(['name', 'type']) - - >>> t = Table('t') - >>> Index(['name', 'type'], table=t) - - ''' - - table_name = self.table.name if self.table else None - return f"" - - - def __str__(self): - ''' - >>> print(Index(['name', 'type'])) - Index([name, type]) - >>> t = Table('t') - >>> print(Index(['name', 'type'], table=t)) - Index(t[name, type]) - ''' - - table_name = self.table.name if self.table else '' - subjects = ', '.join(s for s in self.subject_names) - return f"Index({table_name}[{subjects}])" - - @property - def sql(self): - ''' - Returns inline SQL of the index to be created separately from table - definition: - - CREATE UNIQUE INDEX ON "products" USING HASH ("id"); - - But if it's a (composite) primary key index, returns an inline SQL for - composite primary key to be used inside table definition: - - PRIMARY KEY ("id", "name") - - ''' - self.check_attributes_for_sql() - keys = ', '.join(f'"{key.name}"' if isinstance(key, Column) else key for key in self.subjects) - if self.pk: - return f'PRIMARY KEY ({keys})' - - components = ['CREATE'] - if self.unique: - components.append('UNIQUE') - components.append('INDEX') - if self.name: - components.append(f'"{self.name}"') - components.append(f'ON "{self.table.name}"') - if self.type: - components.append(f'USING {self.type.upper()}') - components.append(f'({keys})') - result = ' '.join(components) + ';' - if self.note: - result += f' {self.note.sql}' - return result - - -class Table(SQLOjbect): - '''Class representing table.''' - - required_attributes = ('name',) - - def __init__(self, - name: str, - alias: Optional[str] = None, - note: Optional[Note] = None, - header_color: Optional[str] = None, - refs: Optional[List[TableReference]] = None, - comment: Optional[str] = None): - self.name = name - self.columns: List[Column] = [] - self.indexes: List[Index] = [] - self.column_dict: Dict[str, Column] = {} - self.alias = alias if alias else None - self.note = note or Note('') - self.header_color = header_color - self.refs = refs or [] - self.comment = comment - - def add_column(self, c: Column) -> None: - ''' - Adds column to self.columns attribute and sets in this column the - `table` attribute. - ''' - c.table = self - self.columns.append(c) - self.column_dict[c.name] = c - - def add_index(self, i: Index) -> None: - ''' - Adds index to self.indexes attribute and sets in this index the - `table` attribute. - ''' - for subj in i.subject_names: - if subj.startswith('(') and subj.endswith(')'): - # subject is an expression, add it as string - i.subjects.append(subj) - else: - try: - col = self[subj] - i.subjects.append(col) - except KeyError: - raise ColumnNotFoundError(f'Cannot add index, column "{subj}" not defined in table "{self.name}".') - - i.table = self - self.indexes.append(i) - - def add_ref(self, r: TableReference) -> None: - ''' - Adds a reference to the table. If reference already present in the table, - raises DuplicateReferenceError. - ''' - if r in self.refs: - raise DuplicateReferenceError(f'Reference with same endpoints {r} already present in the table.') - self.refs.append(r) - - def __getitem__(self, k: Union[int, str]) -> Column: - if isinstance(k, int): - return self.columns[k] - else: - return self.column_dict[k] - - def get(self, k, default=None): - return self.column_dict.get(k, default) - - def __iter__(self): - return iter(self.columns) - - def __repr__(self): - ''' - >>> table = Table('customers') - >>> table -
- ''' - - return f'
' - - def __str__(self): - ''' - >>> table = Table('customers') - >>> table.add_column(Column('id', 'INTEGER')) - >>> table.add_column(Column('name', 'VARCHAR2')) - >>> print(table) - customers(id, name) - ''' - - return f'{self.name}({", ".join(c.name for c in self.columns)})' - - @property - def sql(self): - ''' - Returns full SQL for table definition: - - CREATE TABLE "countries" ( - "code" int PRIMARY KEY, - "name" varchar, - "continent_name" varchar - ); - - Also returns indexes if they were defined: - - CREATE INDEX ON "products" ("id", "name"); - ''' - self.check_attributes_for_sql() - components = [f'CREATE TABLE "{self.name}" ('] - if self.note: - components.append(f' {self.note.sql}') - body = [] - body.extend(' ' + c.sql for c in self.columns) - body.extend(' ' + i.sql for i in self.indexes if i.pk) - body.extend(' ' + r.sql for r in self.refs) - components.append(',\n'.join(body)) - components.append(');\n') - components.extend(i.sql + '\n' for i in self.indexes if not i.pk) - return '\n'.join(components) - - -class EnumItem: - '''Single enum item. Does not translate into SQL''' - - def __init__(self, - name: str, - note: Optional[Note] = None, - comment: Optional[str] = None): - self.name = name - self.note = note or Note('') - self.comment = comment - - def __repr__(self): - ''' - >>> EnumItem('en-US') - - ''' - - return f'' - - def __str__(self): - ''' - >>> print(EnumItem('en-US')) - en-US - ''' - - return self.name - - @property - def sql(self): - components = [f"'{self.name}',"] - if self.note: - components.append(self.note.sql) - return ' '.join(components) - - -class Enum(SQLOjbect): - required_attributes = ('name', 'items') - - def __init__(self, - name: str, - items: List[EnumItem], - comment: Optional[str] = None): - self.name = name - self.items = items - self.comment = comment - - def get_type(self): - return EnumType(self.name, self.items) - - def __getitem__(self, key) -> EnumItem: - return self.items[key] - - def __iter__(self): - return iter(self.items) - - def __repr__(self): - ''' - >>> en = EnumItem('en-US') - >>> ru = EnumItem('ru-RU') - >>> Enum('languages', [en, ru]) - - ''' - - item_names = [i.name for i in self.items] - classname = self.__class__.__name__ - return f'<{classname} {self.name!r}, {item_names!r}>' - - def __str__(self): - ''' - >>> en = EnumItem('en-US') - >>> ru = EnumItem('ru-RU') - >>> print(Enum('languages', [en, ru])) - languages - ''' - - return self.name - - @property - def sql(self): - ''' - Returns SQL for enum type: - - CREATE TYPE "job_status" AS ENUM ( - 'created', - 'running', - 'donef', - 'failure', - ); - - ''' - self.check_attributes_for_sql() - return f'CREATE TYPE "{self.name}" AS ENUM (\n' +\ - '\n'.join(f' {i.sql}' for i in self.items) +\ - '\n);' - - -class EnumType(Enum): - ''' - Enum object, intended to be put in the `type` attribute of a column. - - >>> en = EnumItem('en-US') - >>> ru = EnumItem('ru-RU') - >>> EnumType('languages', [en, ru]) - - >>> print(_) - languages - ''' - - pass - - -class TableGroup: - ''' - TableGroup `items` parameter initially holds just the names of the tables, - but after parsing the whole document, PyDBMLParseResults class replaces - them with references to actual tables. - ''' - - def __init__(self, - name: str, - items: Union[List[str], List[Table]], - comment: Optional[str] = None): - self.name = name - self.items = items - self.comment = comment - - def __repr__(self): - """ - >>> tg = TableGroup('mygroup', ['t1', 't2']) - >>> tg - - >>> t1 = Table('t1') - >>> t2 = Table('t2') - >>> tg.items = [t1, t2] - >>> tg - - """ - - items = [i if isinstance(i, str) else i.name for i in self.items] - return f'' - - def __getitem__(self, key) -> str: - return self.items[key] - - def __iter__(self): - return iter(self.items) - - -class Project: - def __init__(self, - name: str, - items: Optional[Dict[str, str]] = None, - note: Optional[Note] = None, - comment: Optional[str] = None): - self.name = name - self.items = items - self.note = note or Note('') - self.comment = comment - - def __repr__(self): - """ - >>> Project('myproject') - - """ - - return f'' diff --git a/pydbml/classes/__init__.py b/pydbml/classes/__init__.py new file mode 100644 index 0000000..a083781 --- /dev/null +++ b/pydbml/classes/__init__.py @@ -0,0 +1,25 @@ +from .._classes.column import Column +from .._classes.enum import Enum +from .._classes.enum import EnumItem +from .._classes.expression import Expression +from .._classes.index import Index +from .._classes.note import Note +from .._classes.project import Project +from .._classes.reference import Reference +from .._classes.sticky_note import StickyNote +from .._classes.table import Table +from .._classes.table_group import TableGroup + +__all__ = [ + "Column", + "Enum", + "EnumItem", + "Expression", + "Index", + "Note", + "Project", + "Reference", + "StickyNote", + "Table", + "TableGroup", +] diff --git a/pydbml/constants.py b/pydbml/constants.py new file mode 100644 index 0000000..712ac61 --- /dev/null +++ b/pydbml/constants.py @@ -0,0 +1,4 @@ +ONE_TO_MANY = '<' +MANY_TO_ONE = '>' +ONE_TO_ONE = '-' +MANY_TO_MANY = '<>' diff --git a/pydbml/database.py b/pydbml/database.py new file mode 100644 index 0000000..41dd589 --- /dev/null +++ b/pydbml/database.py @@ -0,0 +1,207 @@ +from typing import Any, Type +from typing import Dict +from typing import List +from typing import Optional +from typing import Union + +from ._classes.sticky_note import StickyNote +from .classes import Enum +from .classes import Project +from .classes import Reference +from .classes import Table +from .classes import TableGroup +from .exceptions import DatabaseValidationError +from .renderer.base import BaseRenderer +from .renderer.dbml.default.renderer import DefaultDBMLRenderer +from .renderer.sql.default import DefaultSQLRenderer + + +class Database: + def __init__( + self, + sql_renderer: Type[BaseRenderer] = DefaultSQLRenderer, + dbml_renderer: Type[BaseRenderer] = DefaultDBMLRenderer, + allow_properties: bool = False + ) -> None: + self.sql_renderer = sql_renderer + self.dbml_renderer = dbml_renderer + self.tables: List['Table'] = [] + self.table_dict: Dict[str, 'Table'] = {} + self.refs: List['Reference'] = [] + self.enums: List['Enum'] = [] + self.table_groups: List['TableGroup'] = [] + self.sticky_notes: List['StickyNote'] = [] + self.project: Optional['Project'] = None + self.allow_properties = allow_properties + + def __repr__(self) -> str: + return f"" + + def __getitem__(self, k: Union[int, str]) -> Table: + if isinstance(k, int): + return self.tables[k] + elif isinstance(k, str): + return self.table_dict[k] + else: + raise TypeError('indeces must be str or int') + + def __iter__(self): + return iter(self.tables) + + def _set_database(self, obj: Any) -> None: + obj.database = self + + def _unset_database(self, obj: Any) -> None: + obj.database = None + + def add(self, obj: Any) -> Any: + if isinstance(obj, Table): + return self.add_table(obj) + elif isinstance(obj, Reference): + return self.add_reference(obj) + elif isinstance(obj, Enum): + return self.add_enum(obj) + elif isinstance(obj, TableGroup): + return self.add_table_group(obj) + elif isinstance(obj, Project): + return self.add_project(obj) + elif isinstance(obj, StickyNote): + return self.add_sticky_note(obj) + else: + raise DatabaseValidationError(f'Unsupported type {type(obj)}.') + + def add_table(self, obj: Table) -> Table: + if obj in self.tables: + raise DatabaseValidationError(f'{obj} is already in the database.') + if obj.full_name in self.table_dict: + raise DatabaseValidationError(f'Table {obj.full_name} is already in the database.') + if obj.alias and obj.alias in self.table_dict: + raise DatabaseValidationError(f'Table {obj.alias} is already in the database.') + + self._set_database(obj) + + self.tables.append(obj) + self.table_dict[obj.full_name] = obj + if obj.alias: + self.table_dict[obj.alias] = obj + return obj + + def add_reference(self, obj: Reference): + for col in (*obj.col1, *obj.col2): + if col.table and col.table.database == self: + break + else: + raise DatabaseValidationError( + 'Cannot add reference. At least one of the referenced tables' + ' should belong to this database' + ) + if obj in self.refs: + raise DatabaseValidationError(f'{obj} is already in the database.') + + self._set_database(obj) + self.refs.append(obj) + return obj + + def add_enum(self, obj: Enum) -> Enum: + if obj in self.enums: + raise DatabaseValidationError(f'{obj} is already in the database.') + for enum in self.enums: + if enum.name == obj.name and enum.schema == obj.schema: + raise DatabaseValidationError(f'Enum {obj.schema}.{obj.name} is already in the database.') + + self._set_database(obj) + self.enums.append(obj) + return obj + + def add_sticky_note(self, obj: StickyNote) -> StickyNote: + self._set_database(obj) + self.sticky_notes.append(obj) + return obj + + def add_table_group(self, obj: TableGroup) -> TableGroup: + if obj in self.table_groups: + raise DatabaseValidationError(f'{obj} is already in the database.') + for table_group in self.table_groups: + if table_group.name == obj.name: + raise DatabaseValidationError(f'TableGroup {obj.name} is already in the database.') + + self._set_database(obj) + self.table_groups.append(obj) + return obj + + def add_project(self, obj: Project) -> Project: + if self.project: + self.delete_project() + self._set_database(obj) + self.project = obj + return obj + + def delete(self, obj: Any) -> Any: + if isinstance(obj, Table): + return self.delete_table(obj) + elif isinstance(obj, Reference): + return self.delete_reference(obj) + elif isinstance(obj, Enum): + return self.delete_enum(obj) + elif isinstance(obj, TableGroup): + return self.delete_table_group(obj) + elif isinstance(obj, Project): + return self.delete_project() + else: + raise DatabaseValidationError(f'Unsupported type {type(obj)}.') + + def delete_table(self, obj: Table) -> Table: + try: + index = self.tables.index(obj) + except ValueError: + raise DatabaseValidationError(f'{obj} is not in the database.') + self._unset_database(self.tables.pop(index)) + result = self.table_dict.pop(obj.full_name) + if obj.alias: + self.table_dict.pop(obj.alias) + return result + + def delete_reference(self, obj: Reference) -> Reference: + try: + index = self.refs.index(obj) + except ValueError: + raise DatabaseValidationError(f'{obj} is not in the database.') + result = self.refs.pop(index) + self._unset_database(result) + return result + + def delete_enum(self, obj: Enum) -> Enum: + try: + index = self.enums.index(obj) + except ValueError: + raise DatabaseValidationError(f'{obj} is not in the database.') + result = self.enums.pop(index) + self._unset_database(result) + return result + + def delete_table_group(self, obj: TableGroup) -> TableGroup: + try: + index = self.table_groups.index(obj) + except ValueError: + raise DatabaseValidationError(f'{obj} is not in the database.') + result = self.table_groups.pop(index) + self._unset_database(result) + return result + + def delete_project(self) -> Project: + if self.project is None: + raise DatabaseValidationError(f'Project is not set.') + result = self.project + self.project = None + self._unset_database(result) + return result + + @property + def sql(self): + '''Returs SQL of the parsed results''' + return self.sql_renderer.render_db(self) + + @property + def dbml(self): + '''Generates DBML code out of parsed results''' + return self.dbml_renderer.render_db(self) diff --git a/pydbml/definitions/column.py b/pydbml/definitions/column.py index 8dec3be..1cd37d0 100644 --- a/pydbml/definitions/column.py +++ b/pydbml/definitions/column.py @@ -1,7 +1,5 @@ import pyparsing as pp -from pydbml.classes import Column - from .common import _ from .common import _c from .common import c @@ -16,50 +14,39 @@ from .generic import number_literal from .generic import string_literal from .reference import ref_inline +from pydbml.parser.blueprints import ColumnBlueprint -pp.ParserElement.setDefaultWhitespaceChars(' \t\r') - -type_args = ("(" + pp.originalTextFor(expression)('args') + ")") -type_name = (pp.Word(pp.alphanums + '_') | pp.dblQuotedString())('name') -column_type = (type_name + type_args[0, 1]) +pp.ParserElement.set_default_whitespace_chars(' \t\r') +type_args = ("(" + pp.original_text_for(expression) + ")") -def parse_column_type(s, l, t): - ''' - int or "mytype" or varchar(255) - ''' - result = t['name'] - args = t.get('args') - result += '(' + args + ')' if args else '' - return result - - -column_type.setParseAction(parse_column_type) - +# column type is parsed as a single string, it will be split by blueprint +column_type = pp.Combine((name + pp.Literal('[]')) | (name + '.' + name) | ((name) + type_args[0, 1])) default = pp.CaselessLiteral('default:').suppress() + _ - ( string_literal | expression_literal - | boolean_literal.setParseAction( - lambda s, l, t: { + | boolean_literal.set_parse_action( + lambda s, loc, tok: { 'true': True, 'false': False, 'NULL': None - }[t[0]] + }[tok[0]] ) - | number_literal.setParseAction( - lambda s, l, t: float(''.join(t[0])) if '.' in t[0] else int(t[0]) + | number_literal.set_parse_action( + lambda s, loc, tok: float(''.join(tok[0])) if '.' in tok[0] else int(tok[0]) ) ) +prop = name + pp.Suppress(":") + string_literal column_setting = _ + ( - pp.CaselessLiteral("not null").setParseAction( - lambda s, l, t: True + pp.CaselessLiteral("not null").set_parse_action( + lambda s, loc, tok: True )('notnull') - | pp.CaselessLiteral("null").setParseAction( - lambda s, l, t: False + | pp.CaselessLiteral("null").set_parse_action( + lambda s, loc, tok: False )('notnull') | pp.CaselessLiteral("primary key")('pk') | pk('pk') @@ -69,34 +56,42 @@ def parse_column_type(s, l, t): | ref_inline('ref*') | default('default') ) + _ + +column_setting_with_property = column_setting | prop.set_results_name('property', list_all_matches=True) + column_settings = '[' - column_setting + ("," + column_setting)[...] + ']' + c +column_settings_with_properties = '[' - (_ + column_setting_with_property + _) + ("," + column_setting_with_property)[...] + ']' + c + -def parse_column_settings(s, l, t): +def parse_column_settings(s, loc, tok): ''' [ NOT NULL, increment, default: `now()`] ''' result = {} - if t.get('notnull'): + if tok.get('notnull'): result['not_null'] = True - if 'pk' in t: + if 'pk' in tok: result['pk'] = True - if 'unique' in t: + if 'unique' in tok: result['unique'] = True - if 'increment' in t: + if 'increment' in tok: result['autoinc'] = True - if 'note' in t: - result['note'] = t['note'] - if 'default' in t: - result['default'] = t['default'][0] - if 'ref' in t: - result['ref_blueprints'] = list(t['ref']) - if 'comment' in t: - result['comment'] = t['comment'][0] + if 'note' in tok: + result['note'] = tok['note'] + if 'default' in tok: + result['default'] = tok['default'][0] + if 'ref' in tok: + result['ref_blueprints'] = list(tok['ref']) + if 'comment' in tok: + result['comment'] = tok['comment'][0] + if 'property' in tok: + result['properties'] = {k: v for k, v in tok['property']} return result -column_settings.setParseAction(parse_column_settings) +column_settings.set_parse_action(parse_column_settings) +column_settings_with_properties.set_parse_action(parse_column_settings) constraint = pp.CaselessLiteral("unique") | pp.CaselessLiteral("pk") @@ -109,32 +104,41 @@ def parse_column_settings(s, l, t): ) + n -def parse_column(s, l, t): +table_column_with_properties = _c + ( + name('name') + + column_type('type') + + constraint[...]('constraints') + c + + column_settings_with_properties('settings')[0, 1] +) + n + + +def parse_column(s, loc, tok): ''' address varchar(255) [unique, not null, note: 'to include unit number'] ''' init_dict = { - 'name': t['name'], - 'type_': t['type'], + 'name': tok['name'], + 'type': tok['type'], } # deprecated - for constraint in t.get('constraints', []): + for constraint in tok.get('constraints', []): if constraint == 'pk': init_dict['pk'] = True elif constraint == 'unique': init_dict['unique'] = True - if 'settings' in t: - init_dict.update(t['settings']) + if 'settings' in tok: + init_dict.update(tok['settings']) # comments after column definition have priority - if 'comment' in t: - init_dict['comment'] = t['comment'][0] - if 'comment' not in init_dict and 'comment_before' in t: - comment = '\n'.join(c[0] for c in t['comment_before']) + if 'comment' in tok: + init_dict['comment'] = tok['comment'][0] + if 'comment' not in init_dict and 'comment_before' in tok: + comment = '\n'.join(c[0] for c in tok['comment_before']) init_dict['comment'] = comment - return Column(**init_dict) + return ColumnBlueprint(**init_dict) -table_column.setParseAction(parse_column) +table_column.set_parse_action(parse_column) +table_column_with_properties.set_parse_action(parse_column) diff --git a/pydbml/definitions/common.py b/pydbml/definitions/common.py index 23d874f..8288de3 100644 --- a/pydbml/definitions/common.py +++ b/pydbml/definitions/common.py @@ -1,12 +1,14 @@ import pyparsing as pp -from pydbml.classes import Note - from .generic import string_literal +from pydbml.parser.blueprints import NoteBlueprint -pp.ParserElement.setDefaultWhitespaceChars(' \t\r') +pp.ParserElement.set_default_whitespace_chars(' \t\r') -comment = pp.Suppress("//") + pp.SkipTo(pp.LineEnd()) +comment = ( + pp.Suppress("//") + pp.SkipTo(pp.LineEnd()) + | pp.Suppress('/*') + ... + pp.Suppress('*/') +) # optional comment or newline _ = ('\n' | comment)[...].suppress() @@ -18,15 +20,20 @@ c = comment('comment')[0, 1] n = pp.LineEnd() -end = n | pp.StringEnd() + +end = comment[...].suppress() + n | pp.StringEnd() + # obligatory newline # n = pp.Suppress('\n')[1, ...] note = pp.CaselessLiteral("note:") + _ - string_literal('text') -note.setParseAction(lambda s, l, t: Note(t['text'])) +note.set_parse_action(lambda s, loc, tok: NoteBlueprint(tok['text'])) note_object = pp.CaselessLiteral('note') + _ - '{' + _ - string_literal('text') + _ - '}' -note_object.setParseAction(lambda s, l, t: Note(t['text'])) +note_object.set_parse_action(lambda s, loc, tok: NoteBlueprint(tok['text'])) pk = pp.CaselessLiteral("pk") unique = pp.CaselessLiteral("unique") + +hex_char = pp.Word(pp.srange('[0-9a-fA-F]'), exact=1) +hex_color = ("#" - (hex_char * 3 ^ hex_char * 6)).leaveWhitespace() diff --git a/pydbml/definitions/enum.py b/pydbml/definitions/enum.py index d8bc323..0c6d6fa 100644 --- a/pydbml/definitions/enum.py +++ b/pydbml/definitions/enum.py @@ -1,8 +1,5 @@ import pyparsing as pp -from pydbml.classes import Enum -from pydbml.classes import EnumItem - from .common import _ from .common import _c from .common import c @@ -10,61 +7,64 @@ from .common import n from .common import note from .generic import name +from pydbml.parser.blueprints import EnumBlueprint +from pydbml.parser.blueprints import EnumItemBlueprint -pp.ParserElement.setDefaultWhitespaceChars(' \t\r') +pp.ParserElement.set_default_whitespace_chars(' \t\r') enum_settings = '[' + _ - note('note') + _ - ']' + c -def parse_enum_settings(s, l, t): +def parse_enum_settings(s, loc, tok): ''' [note: "note content"] // comment ''' result = {} - if 'note' in t: - result['note'] = t['note'] - if 'comment' in t: - result['comment'] = t['comment'][0] + if 'note' in tok: + result['note'] = tok['note'] + if 'comment' in tok: + result['comment'] = tok['comment'][0] return result -enum_settings.setParseAction(parse_enum_settings) +enum_settings.set_parse_action(parse_enum_settings) enum_item = _c + (name('name') + c + enum_settings('settings')[0, 1]) -def parse_enum_item(s, l, t): +def parse_enum_item(s, loc, tok): ''' student [note: "is stupid"] ''' - init_dict = {'name': t['name']} - if 'settings' in t: - init_dict.update(t['settings']) - - # comments after settings have priority - if 'comment' in t: - init_dict['comment'] = t['comment'][0] - if 'comment' not in init_dict and 'comment_before' in t: - comment = '\n'.join(c[0] for c in t['comment_before']) + init_dict = {'name': tok['name']} + if 'settings' in tok: + init_dict.update(tok['settings']) + # comments after settings have priority + if 'comment' in tok['settings']: + init_dict['comment'] = tok['settings']['comment'] + if 'comment' not in init_dict and 'comment_before' in tok: + comment = '\n'.join(c[0] for c in tok['comment_before']) init_dict['comment'] = comment - return EnumItem(**init_dict) + return EnumItemBlueprint(**init_dict) -enum_item.setParseAction(parse_enum_item) +enum_item.set_parse_action(parse_enum_item) enum_body = enum_item[1, ...] +enum_name = pp.Combine(name("schema") + '.' + name("name")) | name("name") + enum = _c + ( pp.CaselessLiteral('enum') - - name('name') + _ + - enum_name + _ - '{' + enum_body('items') + n - '}' ) + end -def parse_enum(s, l, t): +def parse_enum(s, loc, tok): ''' enum members { janitor @@ -74,15 +74,18 @@ def parse_enum(s, l, t): } ''' init_dict = { - 'name': t['name'], - 'items': list(t['items']) + 'name': tok['name'], + 'items': list(tok['items']) } - if 'comment_before' in t: - comment = '\n'.join(c[0] for c in t['comment_before']) + if 'schema' in tok: + init_dict['schema'] = tok['schema'] + + if 'comment_before' in tok: + comment = '\n'.join(c[0] for c in tok['comment_before']) init_dict['comment'] = comment - return Enum(**init_dict) + return EnumBlueprint(**init_dict) -enum.setParseAction(parse_enum) +enum.set_parse_action(parse_enum) diff --git a/pydbml/definitions/generic.py b/pydbml/definitions/generic.py index 1bb5650..69c1270 100644 --- a/pydbml/definitions/generic.py +++ b/pydbml/definitions/generic.py @@ -1,6 +1,8 @@ import pyparsing as pp -pp.ParserElement.setDefaultWhitespaceChars(' \t\r') +from pydbml.parser.blueprints import ExpressionBlueprint + +pp.ParserElement.set_default_whitespace_chars(' \t\r') name = pp.Word(pp.alphanums + '_') | pp.QuotedString('"') @@ -15,7 +17,7 @@ pp.Suppress('`') + pp.CharsNotIn('`')[...] + pp.Suppress('`') -).setParseAction(lambda s, l, t: f'({t[0]})') +).set_parse_action(lambda s, lok, tok: ExpressionBlueprint(tok[0])) boolean_literal = ( pp.CaselessLiteral('true') @@ -31,8 +33,8 @@ # Expression -expr_chars = pp.Word(pp.alphanums + "'`,._+- \n\t") -expr_chars_no_comma_space = pp.Word(pp.alphanums + "'`._+-") +expr_chars = pp.Word(pp.alphanums + "\"'`,._+- \n\t") +expr_chars_no_comma_space = pp.Word(pp.alphanums + "\"'`._+-") expression = pp.Forward() factor = ( pp.Word(pp.alphanums + '_')[0, 1] + '(' + expression + ')' diff --git a/pydbml/definitions/index.py b/pydbml/definitions/index.py index 69f704d..ed39fda 100644 --- a/pydbml/definitions/index.py +++ b/pydbml/definitions/index.py @@ -1,7 +1,5 @@ import pyparsing as pp -from pydbml.classes import Index - from .common import _ from .common import _c from .common import c @@ -11,45 +9,52 @@ from .generic import expression_literal from .generic import name from .generic import string_literal +from pydbml.parser.blueprints import ExpressionBlueprint +from pydbml.parser.blueprints import IndexBlueprint -pp.ParserElement.setDefaultWhitespaceChars(' \t\r') +pp.ParserElement.set_default_whitespace_chars(' \t\r') index_type = pp.CaselessLiteral("type:").suppress() + _ - ( - pp.CaselessLiteral("btree")('type') | pp.CaselessLiteral("hash")('type') + pp.CaselessLiteral("brin")('type') | + pp.CaselessLiteral("btree")('type') | + pp.CaselessLiteral("gin")('type') | + pp.CaselessLiteral("gist")('type') | + pp.CaselessLiteral("hash")('type') | + pp.CaselessLiteral("spgist")('type') ) -index_setting = ( +index_setting = _ + ( unique('unique') | index_type | pp.CaselessLiteral("name:") + _ - string_literal('name') | note('note') -) + | pk('pk') +) + _ index_settings = ( - '[' + _ + pk('pk') + _ - ']' + c - | '[' + _ + index_setting + (_ + ',' + _ - index_setting)[...] + _ - ']' + c + '[' + index_setting + (',' - index_setting)[...] - ']' + c ) -def parse_index_settings(s, l, t): +def parse_index_settings(s, lok, tok): ''' [type: btree, name: 'name', unique, note: 'note'] ''' result = {} - if 'unique' in t: + if 'unique' in tok: result['unique'] = True - if 'name' in t: - result['name'] = t['name'] - if 'pk' in t: + if 'name' in tok: + result['name'] = tok['name'] + if 'pk' in tok: result['pk'] = True - if 'type' in t: - result['type_'] = t['type'] - if 'note' in t: - result['note'] = t['note'] - if 'comment' in t: - result['comment'] = t['comment'][0] + if 'type' in tok: + result['type'] = tok['type'] + if 'note' in tok: + result['note'] = tok['note'] + if 'comment' in tok: + result['comment'] = tok['comment'][0] return result -index_settings.setParseAction(parse_index_settings) +index_settings.set_parse_action(parse_index_settings) subject = name | expression_literal composite_index_syntax = ( @@ -72,7 +77,7 @@ def parse_index_settings(s, l, t): ) -def parse_index(s, l, t): +def parse_index(s, lok, tok): ''' (id, country) [pk] // composite primary key or @@ -84,22 +89,22 @@ def parse_index(s, l, t): ] ''' init_dict = {} - if isinstance(t['subject'], str): - subjects = [t['subject']] + if isinstance(tok['subject'], (str, ExpressionBlueprint)): + subjects = [tok['subject']] else: - subjects = list(t['subject']) + subjects = list(tok['subject']) init_dict['subject_names'] = subjects - settings = t.get('settings', {}) + settings = tok.get('settings', {}) init_dict.update(settings) # comments after settings have priority - if 'comment' in t: - init_dict['comment'] = t['comment'][0] - if 'comment' not in init_dict and 'comment_before' in t: - comment = '\n'.join(c[0] for c in t['comment_before']) + if 'comment' in tok: + init_dict['comment'] = tok['comment'][0] + if 'comment' not in init_dict and 'comment_before' in tok: + comment = '\n'.join(c[0] for c in tok['comment_before']) init_dict['comment'] = comment - return Index(**init_dict) + return IndexBlueprint(**init_dict) -index.setParseAction(parse_index) +index.set_parse_action(parse_index) diff --git a/pydbml/definitions/project.py b/pydbml/definitions/project.py index 2f05b56..12fa2fd 100644 --- a/pydbml/definitions/project.py +++ b/pydbml/definitions/project.py @@ -1,20 +1,20 @@ import pyparsing as pp -from pydbml.classes import Note -from pydbml.classes import Project - from .common import _ from .common import _c from .common import n from .common import note +from .common import note_object from .generic import name from .generic import string_literal +from pydbml.parser.blueprints import NoteBlueprint +from pydbml.parser.blueprints import ProjectBlueprint -pp.ParserElement.setDefaultWhitespaceChars(' \t\r') +pp.ParserElement.set_default_whitespace_chars(' \t\r') project_field = pp.Group(name + _ + pp.Suppress(':') + _ - string_literal) -project_element = _ + (note | project_field) + _ +project_element = _ + (note | note_object | project_field) + _ project_body = project_element[...] @@ -27,27 +27,27 @@ ) + (n | pp.StringEnd()) -def parse_project(s, l, t): +def parse_project(s, loc, tok): ''' Project project_name { database_type: 'PostgreSQL' Note: 'Description of the project' } ''' - init_dict = {'name': t['name']} + init_dict = {'name': tok['name']} items = {} - for item in t.get('items', []): - if isinstance(item, Note): + for item in tok.get('items', []): + if isinstance(item, NoteBlueprint): init_dict['note'] = item else: k, v = item items[k] = v if items: init_dict['items'] = items - if 'comment_before' in t: - comment = '\n'.join(c[0] for c in t['comment_before']) + if 'comment_before' in tok: + comment = '\n'.join(c[0] for c in tok['comment_before']) init_dict['comment'] = comment - return Project(**init_dict) + return ProjectBlueprint(**init_dict) -project.setParseAction(parse_project) +project.set_parse_action(parse_project) diff --git a/pydbml/definitions/reference.py b/pydbml/definitions/reference.py index 450c528..9b6450e 100644 --- a/pydbml/definitions/reference.py +++ b/pydbml/definitions/reference.py @@ -1,29 +1,45 @@ import pyparsing as pp -from pydbml.classes import ReferenceBlueprint - from .common import _ from .common import _c from .common import c from .common import n from .generic import name +from pydbml.parser.blueprints import ReferenceBlueprint + +pp.ParserElement.set_default_whitespace_chars(' \t\r') -pp.ParserElement.setDefaultWhitespaceChars(' \t\r') +relation = pp.oneOf("> - < <>") -relation = pp.oneOf("> - <") -ref_inline = pp.Literal("ref:") - relation('type') - name('table') - '.' - name('field') +col_name = ( + ( + name('schema') + '.' + name('table') + '.' - name('field') + ) | ( + name('table') + '.' + name('field') + ) +) +ref_inline = pp.Literal("ref:") - relation('type') - col_name -def parse_inline_relation(s, l, t): + +def parse_inline_relation(s, loc, tok): ''' ref: < table.column + or + ref: < schema1.table.column ''' - return ReferenceBlueprint(type_=t['type'], - table2=t['table'], - col2=t['field']) + result = { + 'type': tok['type'], + 'inline': True, + 'table2': tok['table'], + 'col2': tok['field'] + } + if 'schema' in tok: + result['schema2'] = tok['schema'] + return ReferenceBlueprint(**result) -ref_inline.setParseAction(parse_inline_relation) +ref_inline.set_parse_action(parse_inline_relation) on_option = ( pp.CaselessLiteral('no action') @@ -48,21 +64,21 @@ def parse_inline_relation(s, l, t): ) -def parse_ref_settings(s, l, t): +def parse_ref_settings(s, loc, tok): ''' [delete: cascade] ''' result = {} - if 'update' in t: - result['on_update'] = t['update'][0] - if 'delete' in t: - result['on_delete'] = t['delete'][0] - if 'comment' in t: - result['comment'] = t['comment'][0] + if 'update' in tok: + result['on_update'] = tok['update'][0] + if 'delete' in tok: + result['on_delete'] = tok['delete'][0] + if 'comment' in tok: + result['comment'] = tok['comment'][0] return result -ref_settings.setParseAction(parse_ref_settings) +ref_settings.set_parse_action(parse_ref_settings) composite_name = ( '(' + pp.White()[...] @@ -76,16 +92,53 @@ def parse_ref_settings(s, l, t): ) name_or_composite = name | pp.Combine(composite_name) +ref_cols = ( + ( + name('schema') + + pp.Suppress('.') + name('table') + + pp.Suppress('.') + name_or_composite('field') + ) | ( + name('table') + + pp.Suppress('.') + name_or_composite('field') + ) +) + + +def parse_ref_cols(s, loc, tok): + ''' + table1.col1 + or + schema1.table1.col1 + or + schema1.table1.(col1, col2) + ''' + result = { + 'table': tok['table'], + 'field': tok['field'], + } + if 'schema' in tok: + result['schema'] = tok['schema'] + return result + + +ref_cols.set_parse_action(parse_ref_cols) + ref_body = ( - name('table1') - - '.' - - name_or_composite('field1') + ref_cols('col1') - relation('type') - - name('table2') - - '.' - - name_or_composite('field2') + c + - ref_cols('col2') + c + ref_settings('settings')[0, 1] ) +# ref_body = ( +# table_name('table1') +# - '.' +# - name_or_composite('field1') +# - relation('type') +# - table_name('table2') +# - '.' +# - name_or_composite('field2') + c +# + ref_settings('settings')[0, 1] +# ) ref_short = _c + pp.CaselessLiteral('ref') + name('name')[0, 1] + ':' - ref_body @@ -98,7 +151,7 @@ def parse_ref_settings(s, l, t): ) -def parse_ref(s, l, t): +def parse_ref(s, loc, tok): ''' ref name: table1.col1 > table2.col2 or @@ -107,29 +160,35 @@ def parse_ref(s, l, t): } ''' init_dict = { - 'type_': t['type'], - 'table1': t['table1'], - 'col1': t['field1'], - 'table2': t['table2'], - 'col2': t['field2'] + 'type': tok['type'], + 'inline': False, + 'table1': tok['col1']['table'], + 'col1': tok['col1']['field'], + 'table2': tok['col2']['table'], + 'col2': tok['col2']['field'], } - if 'name' in t: - init_dict['name'] = t['name'] - if 'settings' in t: - init_dict.update(t['settings']) + + if 'schema' in tok['col1']: + init_dict['schema1'] = tok['col1']['schema'] + if 'schema' in tok['col2']: + init_dict['schema2'] = tok['col2']['schema'] + if 'name' in tok: + init_dict['name'] = tok['name'] + if 'settings' in tok: + init_dict.update(tok['settings']) # comments after settings have priority - if 'comment' in t: - init_dict['comment'] = t['comment'][0] - if 'comment' not in init_dict and 'comment_before' in t: - comment = '\n'.join(c[0] for c in t['comment_before']) + if 'comment' in tok: + init_dict['comment'] = tok['comment'][0] + if 'comment' not in init_dict and 'comment_before' in tok: + comment = '\n'.join(c[0] for c in tok['comment_before']) init_dict['comment'] = comment ref = ReferenceBlueprint(**init_dict) return ref -ref_short.setParseAction(parse_ref) -ref_long.setParseAction(parse_ref) +ref_short.set_parse_action(parse_ref) +ref_long.set_parse_action(parse_ref) ref = ref_short | ref_long + (n | pp.StringEnd()) diff --git a/pydbml/definitions/sticky_note.py b/pydbml/definitions/sticky_note.py new file mode 100644 index 0000000..ebd1848 --- /dev/null +++ b/pydbml/definitions/sticky_note.py @@ -0,0 +1,21 @@ +import pyparsing as pp + +from .common import _, end, _c +from .generic import string_literal, name +from ..parser.blueprints import StickyNoteBlueprint + +sticky_note = _c + pp.CaselessLiteral('note') + _ + (name('name') + _ - '{' + _ - string_literal('text') + _ - '}') + end + + +def parse_sticky_note(s, loc, tok): + ''' + Note single_line_note { + 'This is a single line note' + } + ''' + init_dict = {'name': tok['name'], 'text': tok['text']} + + return StickyNoteBlueprint(**init_dict) + + +sticky_note.set_parse_action(parse_sticky_note) diff --git a/pydbml/definitions/table.py b/pydbml/definitions/table.py index adcea9e..aa0dfc9 100644 --- a/pydbml/definitions/table.py +++ b/pydbml/definitions/table.py @@ -1,23 +1,20 @@ import pyparsing as pp -from pydbml.classes import Table - -from .column import table_column -from .common import _ +from pydbml.parser.blueprints import TableBlueprint +from .column import table_column, table_column_with_properties +from .common import _, hex_color from .common import _c from .common import end from .common import note from .common import note_object -from .generic import name +from .generic import name, string_literal from .index import indexes -pp.ParserElement.setDefaultWhitespaceChars(' \t\r') +pp.ParserElement.set_default_whitespace_chars(' \t\r') alias = pp.WordStart() + pp.Literal('as').suppress() - pp.WordEnd() - name -hex_char = pp.Word(pp.srange('[0-9a-fA-F]'), exact=1) -hex_color = ("#" - (hex_char * 3 ^ hex_char * 6)).leaveWhitespace() header_color = ( pp.CaselessLiteral('headercolor:').suppress() + _ - pp.Combine(hex_color)('header_color') @@ -26,37 +23,60 @@ table_settings = '[' + table_setting + (',' + table_setting)[...] + ']' -def parse_table_settings(s, l, t): +def parse_table_settings(s, loc, tok): ''' [headercolor: #cccccc, note: 'note'] ''' result = {} - if 'note' in t: - result['note'] = t['note'] - if 'header_color' in t: - result['header_color'] = t['header_color'] + if 'note' in tok: + result['note'] = tok['note'] + if 'header_color' in tok: + result['header_color'] = tok['header_color'] return result -table_settings.setParseAction(parse_table_settings) +table_settings.set_parse_action(parse_table_settings) note_element = note | note_object -table_element = _ + (note_element('note') | indexes('indexes')) + _ +prop = name + pp.Suppress(":") + string_literal + +table_element = _ + ( + table_column.set_results_name('columns', list_all_matches=True) | + note_element('note') | + indexes.set_results_name('indexes', list_all_matches=True) +) + _ +table_element_with_property = _ + ( + table_column_with_properties.set_results_name('columns', list_all_matches=True) | + note_element('note') | + indexes.set_results_name('indexes', list_all_matches=True) | + prop.set_results_name('property', list_all_matches=True) +) + _ + +table_body = table_element[...] +table_body_with_properties = table_element_with_property[...] -table_body = table_column[1, ...]('columns') + _ + table_element[...] +table_name = (name('schema') + '.' + name('name')) | (name('name')) table = _c + ( pp.CaselessLiteral("table").suppress() - + name('name') + + table_name + alias('alias')[0, 1] + table_settings('settings')[0, 1] + _ + '{' - table_body + _ + '}' ) + end +table_with_properties = _c + ( + pp.CaselessLiteral("table").suppress() + + table_name + + alias('alias')[0, 1] + + table_settings('settings')[0, 1] + _ + + '{' - table_body_with_properties + _ + '}' +) + end + -def parse_table(s, l, t): +def parse_table(s, loc, tok): ''' Table bookings as bb [headercolor: #cccccc] { id integer @@ -70,25 +90,34 @@ def parse_table(s, l, t): } ''' init_dict = { - 'name': t['name'], + 'name': tok['name'], } - if 'settings' in t: - init_dict.update(t['settings']) - if 'alias' in t: - init_dict['alias'] = t['alias'][0] - if 'note' in t: + if 'schema' in tok: + init_dict['schema'] = tok['schema'] + if 'settings' in tok: + init_dict.update(tok['settings']) + if 'alias' in tok: + init_dict['alias'] = tok['alias'][0] + if 'note' in tok: # will override one from settings - init_dict['note'] = t['note'][0] - if'comment_before' in t: - comment = '\n'.join(c[0] for c in t['comment_before']) + init_dict['note'] = tok['note'][0] + if 'indexes' in tok: + init_dict['indexes'] = tok['indexes'][0] + if 'columns' in tok: + init_dict['columns'] = tok['columns'] + if 'comment_before' in tok: + comment = '\n'.join(c[0] for c in tok['comment_before']) init_dict['comment'] = comment - result = Table(**init_dict) - for column in t['columns']: - result.add_column(column) - for index_ in t.get('indexes', []): - result.add_index(index_) + if 'property' in tok: + init_dict['properties'] = {k: v for k, v in tok['property']} + + if not init_dict.get('columns'): + raise SyntaxError(f'Table {init_dict["name"]} at position {loc} has no columns!') + + result = TableBlueprint(**init_dict) return result -table.setParseAction(parse_table) +table.set_parse_action(parse_table) +table_with_properties.set_parse_action(parse_table) diff --git a/pydbml/definitions/table_group.py b/pydbml/definitions/table_group.py index 63f6710..b8e5b16 100644 --- a/pydbml/definitions/table_group.py +++ b/pydbml/definitions/table_group.py @@ -1,24 +1,40 @@ import pyparsing as pp -from pydbml.classes import TableGroup - -from .common import _ +from pydbml.parser.blueprints import TableGroupBlueprint, NoteBlueprint +from .common import _, note, note_object, hex_color from .common import _c from .common import end from .generic import name -pp.ParserElement.setDefaultWhitespaceChars(' \t\r') +pp.ParserElement.set_default_whitespace_chars(' \t\r') + +table_name = pp.Combine(name + '.' + name) | name +note_element = note | note_object + +tg_element = _ + (note_element('note') | table_name.set_results_name('items', list_all_matches=True)) + _ + +tg_body = tg_element[...] + + +tg_color = ( + pp.CaselessLiteral('color:').suppress() + _ + - pp.Combine(hex_color)('color') +) +tg_setting = _ + (note('note') | tg_color) + _ + +tg_settings = '[' + tg_setting + (',' + tg_setting)[...] + ']' table_group = _c + ( pp.CaselessLiteral('TableGroup') - name('name') + _ + + tg_settings[0, 1] + _ - '{' + _ - - (name + _)[...]('items') + _ + - tg_body + _ - '}' ) + end -def parse_table_group(s, l, t): +def parse_table_group(s, loc, tok): ''' TableGroup tablegroup_name { table1 @@ -27,13 +43,18 @@ def parse_table_group(s, l, t): } ''' init_dict = { - 'name': t['name'], - 'items': list(t.get('items', [])) + 'name': tok['name'], + 'items': list(tok.get('items', [])) } - if 'comment_before' in t: - comment = '\n'.join(c[0] for c in t['comment_before']) + if 'comment_before' in tok: + comment = '\n'.join(c[0] for c in tok['comment_before']) init_dict['comment'] = comment - return TableGroup(**init_dict) + if 'note' in tok: + note = tok['note'] + init_dict['note'] = note if isinstance(note, NoteBlueprint) else note[0] + if 'color' in tok: + init_dict['color'] = tok['color'] + return TableGroupBlueprint(**init_dict) -table_group.setParseAction(parse_table_group) +table_group.set_parse_action(parse_table_group) diff --git a/pydbml/exceptions.py b/pydbml/exceptions.py index b1acf7e..b5914f7 100644 --- a/pydbml/exceptions.py +++ b/pydbml/exceptions.py @@ -6,9 +6,29 @@ class ColumnNotFoundError(Exception): pass +class IndexNotFoundError(Exception): + pass + + class AttributeMissingError(Exception): pass class DuplicateReferenceError(Exception): pass + + +class UnknownDatabaseError(Exception): + pass + + +class DBMLError(Exception): + pass + + +class DatabaseValidationError(Exception): + pass + + +class ValidationError(Exception): + pass diff --git a/pydbml/parser.py b/pydbml/parser.py deleted file mode 100644 index 3b969ee..0000000 --- a/pydbml/parser.py +++ /dev/null @@ -1,270 +0,0 @@ -from __future__ import annotations - -import pyparsing as pp - -from io import TextIOWrapper -from pathlib import Path - -from typing import Dict -from typing import List -from typing import Optional -from typing import Union - -from .classes import Enum -from .classes import Project -from .classes import Reference -from .classes import ReferenceBlueprint -from .classes import Table -from .classes import TableGroup -from .classes import TableReference -from .definitions.common import _ -from .definitions.common import comment -from .definitions.enum import enum -from .definitions.project import project -from .definitions.reference import ref -from .definitions.table import table -from .definitions.table_group import table_group -from .exceptions import ColumnNotFoundError -from .exceptions import TableNotFoundError - -pp.ParserElement.setDefaultWhitespaceChars(' \t\r') - - -class PyDBML: - ''' - PyDBML parser factory. If properly initiated, returns PyDBMLParseResults - which contains parse results in attributes. - - Usage option 1: - - >>> with open('schema.dbml') as f: - ... p = PyDBML(f) - ... # or - ... p = PyDBML(f.read()) - - Usage option 2: - >>> p = PyDBML.parse_file('schema.dbml') - >>> # or - >>> from pathlib import Path - >>> p = PyDBML(Path('schema.dbml')) - ''' - - def __new__(cls, - source_: Optional[Union[str, Path, TextIOWrapper]] = None): - if source_ is not None: - if isinstance(source_, str): - source = source_ - elif isinstance(source_, Path): - with open(source_, encoding='utf8') as f: - source = f.read() - else: # TextIOWrapper - source = source_.read() - if source[0] == '\ufeff': # removing BOM - source = source[1:] - return cls.parse(source) - else: - return super().__new__(cls) - - def __repr__(self): - return "" - - @staticmethod - def parse(text: str) -> PyDBMLParseResults: - if text[0] == '\ufeff': # removing BOM - text = text[1:] - return PyDBMLParseResults(text) - - @staticmethod - def parse_file(file: Union[str, Path, TextIOWrapper]): - if isinstance(file, TextIOWrapper): - source = file.read() - else: - with open(file, encoding='utf8') as f: - source = f.read() - if source[0] == '\ufeff': # removing BOM - source = source[1:] - return PyDBMLParseResults(source) - - -class PyDBMLParseResults: - def __init__(self, source: str): - self.tables: List[Table] = [] - self.table_dict: Dict[str, Table] = {} - self.refs: List[Reference] = [] - self.ref_blueprints: List[ReferenceBlueprint] = [] - self.enums: List[Enum] = [] - self.table_groups: List[TableGroup] = [] - self.project: Optional[Project] = None - self.source = source - - self._set_syntax() - self._syntax.parseString(self.source, parseAll=True) - self._validate() - self._process_refs() - self._process_table_groups() - self._set_enum_types() - - def __repr__(self): - return "" - - def _set_syntax(self): - table_expr = table.copy() - ref_expr = ref.copy() - enum_expr = enum.copy() - table_group_expr = table_group.copy() - project_expr = project.copy() - - table_expr.addParseAction(self._parse_table) - ref_expr.addParseAction(self._parse_ref_blueprint) - enum_expr.addParseAction(self._parse_enum) - table_group_expr.addParseAction(self._parse_table_group) - project_expr.addParseAction(self._parse_project) - - expr = ( - table_expr - | ref_expr - | enum_expr - | table_group_expr - | project_expr - ) - self._syntax = expr[...] + ('\n' | comment)[...] + pp.StringEnd() - - def __getitem__(self, k: Union[int, str]) -> Table: - if isinstance(k, int): - return self.tables[k] - else: - return self.table_dict[k] - - def __iter__(self): - return iter(self.tables) - - def _parse_table(self, s, l, t): - table = t[0] - self.tables.append(table) - for col in table.columns: - self.ref_blueprints.extend(col.ref_blueprints) - self.table_dict[table.name] = table - - def _parse_ref_blueprint(self, s, l, t): - self.ref_blueprints.append(t[0]) - - def _parse_enum(self, s, l, t): - self.enums.append(t[0]) - - def _parse_table_group(self, s, l, t): - self.table_groups.append(t[0]) - - def _parse_project(self, s, l, t): - if not self.project: - self.project = t[0] - else: - raise SyntaxError('Project redifinition not allowed') - - def _process_refs(self): - ''' - Fill up the `refs` attribute with Reference object, created from - reference blueprints; - Add TableReference objects to each table which has references. - Validate refs at the same time. - ''' - for ref_ in self.ref_blueprints: - for tb in self.tables: - if tb.name == ref_.table1 or tb.alias == ref_.table1: - table1 = tb - break - else: - raise TableNotFoundError('Error while parsing reference:' - f'table "{ref_.table1}"" is not defined.') - for tb in self.tables: - if tb.name == ref_.table2 or tb.alias == ref_.table2: - table2 = tb - break - else: - raise TableNotFoundError('Error while parsing reference:' - f'table "{ref_.table2}"" is not defined.') - col1_names = [c.strip('() ') for c in ref_.col1.split(',')] - col1 = [] - for col_name in col1_names: - try: - col1.append(table1[col_name]) - except KeyError: - raise ColumnNotFoundError('Error while parsing reference:' - f'column "{col_name} not defined in table "{table1.name}".') - col2_names = [c.strip('() ') for c in ref_.col2.split(',')] - col2 = [] - for col_name in col2_names: - try: - col2.append(table2[col_name]) - except KeyError: - raise ColumnNotFoundError('Error while parsing reference:' - f'column "{col_name} not defined in table "{table2.name}".') - self.refs.append( - Reference( - ref_.type, - table1, - col1, - table2, - col2, - name=ref_.name, - comment=ref_.comment, - on_update=ref_.on_update, - on_delete=ref_.on_delete - ) - ) - - if ref_.type in (Reference.MANY_TO_ONE, Reference.ONE_TO_ONE): - table = table1 - init_dict = { - 'col': col1, - 'ref_table': table2, - 'ref_col': col2, - 'name': ref_.name, - 'on_update': ref_.on_update, - 'on_delete': ref_.on_delete - } - else: - table = table2 - init_dict = { - 'col': col2, - 'ref_table': table1, - 'ref_col': col1, - 'name': ref_.name, - 'on_update': ref_.on_update, - 'on_delete': ref_.on_delete - } - table.add_ref( - TableReference(**init_dict) - ) - - def _set_enum_types(self): - enum_dict = {enum.name: enum for enum in self.enums} - for table_ in self.tables: - for col in table_: - if str(col.type) in enum_dict: - col.type = enum_dict[str(col.type)].get_type() - - def _validate(self): - self._validate_table_groups() - - def _validate_table_groups(self): - ''' - Check that all tables, mentioned in the table groups, exist - ''' - for tg in self.table_groups: - for table_name in tg: - if table_name not in self.table_dict: - raise TableNotFoundError(f'Cannot add Table Group "{tg.name}": table "{table_name}" not found.') - - def _process_table_groups(self): - ''' - Fill up each TableGroup's `item` attribute with references to actual tables. - ''' - for tg in self.table_groups: - tg.items = [self[i] for i in tg.items] - - @property - def sql(self): - '''Returs SQL of the parsed results''' - - components = (i.sql for i in (*self.enums, *self.tables)) - return '\n'.join(components) diff --git a/pydbml/parser/__init__.py b/pydbml/parser/__init__.py new file mode 100644 index 0000000..aa03f88 --- /dev/null +++ b/pydbml/parser/__init__.py @@ -0,0 +1 @@ +from .parser import PyDBML diff --git a/pydbml/parser/blueprints.py b/pydbml/parser/blueprints.py new file mode 100644 index 0000000..614a251 --- /dev/null +++ b/pydbml/parser/blueprints.py @@ -0,0 +1,328 @@ +from dataclasses import dataclass +from typing import Any +from typing import Collection +from typing import Dict +from typing import List +from typing import Literal +from typing import Optional +from typing import Union + +from pydbml.classes import Column +from pydbml.classes import Enum +from pydbml.classes import EnumItem +from pydbml.classes import Expression +from pydbml.classes import Index +from pydbml.classes import Note +from pydbml.classes import Project +from pydbml.classes import Reference +from pydbml.classes import Table +from pydbml.classes import TableGroup +from pydbml._classes.sticky_note import StickyNote +from pydbml.exceptions import ColumnNotFoundError +from pydbml.exceptions import TableNotFoundError +from pydbml.exceptions import ValidationError +from pydbml.tools import remove_indentation +from pydbml.tools import strip_empty_lines + + +class Blueprint: + parser = None + + +@dataclass +class NoteBlueprint(Blueprint): + text: str + + def _preformat_text(self) -> str: + '''Preformat the note text for idempotence''' + result = strip_empty_lines(self.text) + result = remove_indentation(result) + return result + + def build(self) -> 'Note': + text = self._preformat_text() + return Note(text) + + +@dataclass +class StickyNoteBlueprint(Blueprint): + name: str + text: str + + def _preformat_text(self) -> str: + '''Preformat the note text for idempotence''' + result = strip_empty_lines(self.text) + result = remove_indentation(result) + return result + + def build(self) -> StickyNote: + text = self._preformat_text() + name = self.name + return StickyNote(name=name, text=text) + + +@dataclass +class ExpressionBlueprint(Blueprint): + text: str + + def build(self) -> Expression: + return Expression(self.text) + + +@dataclass +class ReferenceBlueprint(Blueprint): + type: Literal['>', '<', '-', '<>'] + inline: bool + name: Optional[str] = None + schema1: str = 'public' + table1: Optional[str] = None + col1: Optional[Union[str, Collection[str]]] = None + schema2: str = 'public' + table2: Optional[str] = None + col2: Optional[Union[str, Collection[str]]] = None + comment: Optional[str] = None + on_update: Optional[str] = None + on_delete: Optional[str] = None + + def build(self) -> 'Reference': + ''' + both tables and columns should be present before build + ''' + if self.table1 is None: + raise TableNotFoundError("Can't build Reference, table1 unknown") + if self.table2 is None: + raise TableNotFoundError("Can't build Reference, table2 unknown") + if self.col1 is None: + raise ColumnNotFoundError("Can't build Reference, col1 unknown") + if self.col2 is None: + raise ColumnNotFoundError("Can't build Reference, col2 unknown") + + if self.parser: + table1 = self.parser.locate_table(self.schema1, self.table1) + else: + raise RuntimeError('Parser is not set') + + col1_list = [c.strip('() ') for c in self.col1.split(',')] + col1 = [table1[col] for col in col1_list] + + table2 = self.parser.locate_table(self.schema2, self.table2) + col2_list = [c.strip('() ') for c in self.col2.split(',')] + col2 = [table2[col] for col in col2_list] + + return Reference( + type=self.type, + inline=self.inline, + col1=col1, + col2=col2, + name=self.name, + comment=self.comment, + on_update=self.on_update, + on_delete=self.on_delete + ) + + +@dataclass +class ColumnBlueprint(Blueprint): + name: str + type: str + unique: bool = False + not_null: bool = False + pk: bool = False + autoinc: bool = False + default: Optional[Any] = None + note: Optional[NoteBlueprint] = None + ref_blueprints: Optional[List[ReferenceBlueprint]] = None + comment: Optional[str] = None + properties: Optional[Dict[str, str]] = None + + def build(self) -> 'Column': + if isinstance(self.default, ExpressionBlueprint): + self.default = self.default.build() + if self.parser: + if '.' in self.type: + schema, name = self.type.split('.') + else: + schema, name = 'public', self.type + for enum in self.parser.database.enums: + if (enum.schema, enum.name) == (schema, name): + self.type = enum + break + return Column( + name=self.name, + type=self.type, + unique=self.unique, + not_null=self.not_null, + pk=self.pk, + autoinc=self.autoinc, + default=self.default, + note=self.note.build() if self.note else None, + comment=self.comment, + properties=self.properties, + ) + + +@dataclass +class IndexBlueprint(Blueprint): + subject_names: List[Union[str, ExpressionBlueprint]] + name: Optional[str] = None + unique: bool = False + type: Optional[ + Literal[ + # https://www.postgresql.org/docs/current/indexes-types.html + "brin", + "btree", + "gin", + "gist", + "hash", + "spgist", + ] + ] = None + pk: bool = False + note: Optional[NoteBlueprint] = None + comment: Optional[str] = None + + table = None + + def build(self) -> 'Index': + return Index( + # TableBlueprint will process subjects + subjects=[], + name=self.name, + unique=self.unique, + type=self.type, + pk=self.pk, + note=self.note.build() if self.note else None, + comment=self.comment + ) + + +@dataclass +class TableBlueprint(Blueprint): + name: str + schema: str = 'public' + columns: Optional[List[ColumnBlueprint]] = None + indexes: Optional[List[IndexBlueprint]] = None + alias: Optional[str] = None + note: Optional[NoteBlueprint] = None + header_color: Optional[str] = None + comment: Optional[str] = None + properties: Optional[Dict[str, str]] = None + + def build(self) -> 'Table': + result = Table( + name=self.name, + schema=self.schema, + alias=self.alias, + note=self.note.build() if self.note else None, + header_color=self.header_color, + comment=self.comment, + properties=self.properties + ) + columns = self.columns or [] + indexes = self.indexes or [] + for col_bp in columns: + result.add_column(col_bp.build()) + for index_bp in indexes: + index = index_bp.build() + new_subjects: List[Union[str, Column, Expression]] = [] + for subj in index_bp.subject_names: + if isinstance(subj, ExpressionBlueprint): + new_subjects.append(subj.build()) + else: + for col in result.columns: + if col.name == subj: + new_subjects.append(col) + break + else: + raise ColumnNotFoundError( + f'Cannot add index, column "{subj}" not defined in' + ' table "{self.name}".' + ) + index.subjects = new_subjects + result.add_index(index) + return result + + def get_reference_blueprints(self): + ''' the inline ones ''' + result = [] + for col in self.columns: + for ref_bp in col.ref_blueprints or []: + ref_bp.schema1 = self.schema + ref_bp.table1 = self.name + ref_bp.col1 = col.name + result.append(ref_bp) + return result + + +@dataclass +class EnumItemBlueprint(Blueprint): + name: str + note: Optional[NoteBlueprint] = None + comment: Optional[str] = None + + def build(self) -> 'EnumItem': + return EnumItem( + name=self.name, + note=self.note.build() if self.note else None, + comment=self.comment + ) + + +@dataclass +class EnumBlueprint(Blueprint): + name: str + items: List[EnumItemBlueprint] + schema: str = 'public' + comment: Optional[str] = None + + def build(self) -> 'Enum': + return Enum( + name=self.name, + items=[ei.build() for ei in self.items], + schema=self.schema, + comment=self.comment + ) + + +@dataclass +class ProjectBlueprint(Blueprint): + name: str + items: Optional[Dict[str, str]] = None + note: Optional[NoteBlueprint] = None + comment: Optional[str] = None + + def build(self) -> 'Project': + return Project( + name=self.name, + items=dict(self.items) if self.items else {}, + note=self.note.build() if self.note else None, + comment=self.comment + ) + + +@dataclass +class TableGroupBlueprint(Blueprint): + name: str + items: List[str] + comment: Optional[str] = None + note: Optional[NoteBlueprint] = None + color: Optional[str] = None + + def build(self) -> 'TableGroup': + if not self.parser: + raise RuntimeError('Parser is not set') + items = [] + for table_name in self.items: + components = table_name.split('.') + schema, table = components if len(components) == 2 else ('public', components[0]) + table_obj = self.parser.locate_table(schema, table) + if table_obj in items: + raise ValidationError(f'Table "{table}" is already in group "{self.name}"') + items.append(table_obj) + return TableGroup( + name=self.name, + items=items, + comment=self.comment, + note=self.note.build() if self.note else None, + color=self.color + ) diff --git a/pydbml/parser/parser.py b/pydbml/parser/parser.py new file mode 100644 index 0000000..53a5e92 --- /dev/null +++ b/pydbml/parser/parser.py @@ -0,0 +1,239 @@ +from __future__ import annotations + +from io import TextIOWrapper +from pathlib import Path +from typing import List +from typing import Optional +from typing import Type +from typing import Union + +import pyparsing as pp + +from pydbml.classes import Table +from pydbml.database import Database +from pydbml.definitions.common import comment +from pydbml.definitions.enum import enum +from pydbml.definitions.project import project +from pydbml.definitions.reference import ref +from pydbml.definitions.sticky_note import sticky_note +from pydbml.definitions.table import table, table_with_properties +from pydbml.definitions.table_group import table_group +from pydbml.exceptions import TableNotFoundError +from pydbml.renderer.base import BaseRenderer +from pydbml.renderer.dbml.default import DefaultDBMLRenderer +from pydbml.renderer.sql.default import DefaultSQLRenderer +from pydbml.tools import remove_bom +from .blueprints import EnumBlueprint, StickyNoteBlueprint +from .blueprints import ProjectBlueprint +from .blueprints import ReferenceBlueprint +from .blueprints import TableBlueprint +from .blueprints import TableGroupBlueprint + +pp.ParserElement.set_default_whitespace_chars(" \t\r") + + +class PyDBML: + """ + PyDBML parser factory. If properly initiated, returns parsed Database. + + Usage option 1: + + >>> with open('test_schema.dbml') as f: + ... p = PyDBML(f) + ... # or + ... p = PyDBML(f.read()) + + Usage option 2: + >>> p = PyDBML.parse_file('test_schema.dbml') + >>> # or + >>> from pathlib import Path + >>> p = PyDBML(Path('test_schema.dbml')) + """ + + def __new__( + cls, + source_: Optional[Union[str, Path, TextIOWrapper]] = None, + allow_properties: bool = False, + sql_renderer: Type[BaseRenderer] = DefaultSQLRenderer, + dbml_renderer: Type[BaseRenderer] = DefaultDBMLRenderer, + ): + if source_ is not None: + if isinstance(source_, str): + source = source_ + elif isinstance(source_, Path): + with open(source_, encoding="utf8") as f: + source = f.read() + elif isinstance(source_, TextIOWrapper): + source = source_.read() + else: + raise TypeError("Source must be str, path or file stream") + + source = remove_bom(source) + return cls.parse( + source, + allow_properties=allow_properties, + sql_renderer=sql_renderer, + dbml_renderer=dbml_renderer, + ) + else: + return super().__new__(cls) + + def __repr__(self): + return "" + + @staticmethod + def parse( + text: str, + allow_properties: bool = False, + sql_renderer: Type[BaseRenderer] = DefaultSQLRenderer, + dbml_renderer: Type[BaseRenderer] = DefaultDBMLRenderer, + ) -> Database: + text = remove_bom(text) + parser = PyDBMLParser( + text, + allow_properties=allow_properties, + sql_renderer=sql_renderer, + dbml_renderer=dbml_renderer, + ) + return parser.parse() + + @staticmethod + def parse_file(file: Union[str, Path, TextIOWrapper]) -> Database: + if isinstance(file, TextIOWrapper): + source = file.read() + else: + with open(file, encoding="utf8") as f: + source = f.read() + source = remove_bom(source) + parser = PyDBMLParser(source) + return parser.parse() + + +class PyDBMLParser: + def __init__( + self, + source: str, + allow_properties: bool = False, + sql_renderer: Type[BaseRenderer] = DefaultSQLRenderer, + dbml_renderer: Type[BaseRenderer] = DefaultDBMLRenderer, + ): + self.database = None + + self.ref_blueprints: List[ReferenceBlueprint] = [] + self.table_groups: List[TableGroupBlueprint] = [] + self.source = source + self.tables: List[TableBlueprint] = [] + self.refs: List[ReferenceBlueprint] = [] + self.enums: List[EnumBlueprint] = [] + self.project: Optional[ProjectBlueprint] = None + self.sticky_notes: List[StickyNoteBlueprint] = [] + self._allow_properties = allow_properties + self._sql_renderer = sql_renderer + self._dbml_renderer = dbml_renderer + + def parse(self): + self._set_syntax() + self._syntax.parse_string(self.source, parseAll=True) + self.build_database() + return self.database + + def __repr__(self): + return "" + + def _set_syntax(self): + table_expr = ( + table_with_properties.copy() if self._allow_properties else table.copy() + ) + ref_expr = ref.copy() + enum_expr = enum.copy() + table_group_expr = table_group.copy() + project_expr = project.copy() + note_expr = sticky_note.copy() + + table_expr.addParseAction(self.parse_blueprint) + ref_expr.addParseAction(self.parse_blueprint) + enum_expr.addParseAction(self.parse_blueprint) + table_group_expr.addParseAction(self.parse_blueprint) + project_expr.addParseAction(self.parse_blueprint) + note_expr.addParseAction(self.parse_blueprint) + + expr = ( + table_expr + | ref_expr + | enum_expr + | table_group_expr + | project_expr + | note_expr + ) + self._syntax = expr[...] + ("\n" | comment)[...] + pp.StringEnd() + + def parse_blueprint(self, s, loc, tok): + blueprint = tok[0] + if isinstance(blueprint, TableBlueprint): + self.tables.append(blueprint) + ref_bps = blueprint.get_reference_blueprints() + col_bps = blueprint.columns or [] + index_bps = blueprint.indexes or [] + for ref_bp in ref_bps: + self.refs.append(ref_bp) + ref_bp.parser = self + for col_bp in col_bps: + col_bp.parser = self + if col_bp.note: + col_bp.note.parser = self + for index_bp in index_bps: + index_bp.parser = self + if index_bp.note: + index_bp.note.parser = self + if blueprint.note: + blueprint.note.parser = self + elif isinstance(blueprint, ReferenceBlueprint): + self.refs.append(blueprint) + elif isinstance(blueprint, EnumBlueprint): + self.enums.append(blueprint) + for enum_item in blueprint.items: + if enum_item.note: + enum_item.note.parser = self + elif isinstance(blueprint, TableGroupBlueprint): + self.table_groups.append(blueprint) + elif isinstance(blueprint, ProjectBlueprint): + self.project = blueprint + if blueprint.note: + blueprint.note.parser = self + elif isinstance(blueprint, StickyNoteBlueprint): + self.sticky_notes.append(blueprint) + else: + raise RuntimeError(f"type unknown: {blueprint}") + blueprint.parser = self + + def locate_table(self, schema: str, name: str) -> "Table": + if not self.database: + raise RuntimeError("Database is not ready") + # first by alias + result = self.database.table_dict.get(name) + if result is None: + full_name = f"{schema}.{name}" + result = self.database.table_dict.get(full_name) + if result is None: + raise TableNotFoundError(f"Table {full_name} not present in the database") + return result + + def build_database(self): + self.database = Database( + allow_properties=self._allow_properties, + sql_renderer=self._sql_renderer, + dbml_renderer=self._dbml_renderer, + ) + for enum_bp in self.enums: + self.database.add(enum_bp.build()) + for table_bp in self.tables: + self.database.add(table_bp.build()) + self.ref_blueprints.extend(table_bp.get_reference_blueprints()) + for table_group_bp in self.table_groups: + self.database.add(table_group_bp.build()) + for note_bp in self.sticky_notes: + self.database.add(note_bp.build()) + if self.project: + self.database.add(self.project.build()) + for ref_bp in self.refs: + self.database.add(ref_bp.build()) diff --git a/pydbml/renderer/__init__.py b/pydbml/renderer/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pydbml/renderer/base.py b/pydbml/renderer/base.py new file mode 100644 index 0000000..4f1dd28 --- /dev/null +++ b/pydbml/renderer/base.py @@ -0,0 +1,38 @@ +from typing import Type, Callable, Dict, TYPE_CHECKING + +if TYPE_CHECKING: # pragma: no cover + from pydbml.database import Database + + +def unsupported_renderer(model) -> str: + return '' + + +class BaseRenderer: + _unsupported_renderer = unsupported_renderer + + @property + def model_renderers(cls) -> Dict[Type, Callable]: + """A class attribute dictionary to store the model renderers.""" + raise NotImplementedError # pragma: no cover + + @classmethod + def render(cls, model) -> str: + """ + Render the model to a string. If the model is not supported, fall back to + `self._unsupported_renderer` that by default returns an empty string. + """ + + return cls.model_renderers.get(type(model), cls._unsupported_renderer)(model) # type: ignore + + @classmethod + def renderer_for(cls, model_cls: Type) -> Callable: + """A decorator to register a renderer for a model class.""" + def decorator(func) -> Callable: + cls.model_renderers[model_cls] = func # type: ignore + return func + return decorator + + @classmethod + def render_db(cls, db: 'Database') -> str: + raise NotImplementedError # pragma: no cover diff --git a/pydbml/renderer/dbml/__init__.py b/pydbml/renderer/dbml/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pydbml/renderer/dbml/default/__init__.py b/pydbml/renderer/dbml/default/__init__.py new file mode 100644 index 0000000..ca43367 --- /dev/null +++ b/pydbml/renderer/dbml/default/__init__.py @@ -0,0 +1,11 @@ +from .column import render_column +from .enum import render_enum, render_enum_item +from .expression import render_expression +from .index import render_index +from .note import render_note +from .project import render_project +from .reference import render_reference +from .renderer import DefaultDBMLRenderer +from .sticky_note import render_sticky_note +from .table import render_table +from .table_group import render_table_group diff --git a/pydbml/renderer/dbml/default/column.py b/pydbml/renderer/dbml/default/column.py new file mode 100644 index 0000000..3ec037a --- /dev/null +++ b/pydbml/renderer/dbml/default/column.py @@ -0,0 +1,55 @@ +from typing import Union + +from pydbml.classes import Column, Enum, Expression +from pydbml.renderer.dbml.default.renderer import DefaultDBMLRenderer +from pydbml.renderer.dbml.default.utils import comment_to_dbml, note_option_to_dbml, quote_string, prepare_text_for_dbml +from pydbml.renderer.sql.default.utils import get_full_name_for_sql + + +def default_to_str(val: Union[Expression, str, int, float]) -> str: + if isinstance(val, str): + if val.lower() in ('null', 'true', 'false'): + return val.lower() + else: + return f"'{prepare_text_for_dbml(val)}'" + elif isinstance(val, Expression): + return val.dbml + else: # int or float or bool + return str(val) + + +def render_options(model: Column) -> str: + options = [ref.dbml for ref in model.get_refs() if ref.inline] + if model.pk: + options.append('pk') + if model.autoinc: + options.append('increment') + if model.default: + options.append(f'default: {default_to_str(model.default)}') + if model.unique: + options.append('unique') + if model.not_null: + options.append('not null') + if model.note: + options.append(note_option_to_dbml(model.note)) + if model.properties: + if model.table and model.table.database and model.table.database.allow_properties: + for key, value in model.properties.items(): + options.append(f'{key}: {quote_string(value)}') + + if options: + return f' [{", ".join(options)}]' + return '' + + +@DefaultDBMLRenderer.renderer_for(Column) +def render_column(model: Column) -> str: + result = comment_to_dbml(model.comment) if model.comment else '' + result += f'"{model.name}" ' + if isinstance(model.type, Enum): + result += get_full_name_for_sql(model.type) + else: + result += model.type + + result += render_options(model) + return result diff --git a/pydbml/renderer/dbml/default/enum.py b/pydbml/renderer/dbml/default/enum.py new file mode 100644 index 0000000..f89c310 --- /dev/null +++ b/pydbml/renderer/dbml/default/enum.py @@ -0,0 +1,25 @@ +from textwrap import indent + +from pydbml.classes import Enum, EnumItem +from pydbml.renderer.dbml.default.renderer import DefaultDBMLRenderer +from pydbml.renderer.dbml.default.utils import comment_to_dbml, note_option_to_dbml +from pydbml.renderer.sql.default.utils import get_full_name_for_sql + + +@DefaultDBMLRenderer.renderer_for(Enum) +def render_enum(model: Enum) -> str: + result = comment_to_dbml(model.comment) if model.comment else '' + result += f'Enum {get_full_name_for_sql(model)} {{\n' + items_str = '\n'.join(DefaultDBMLRenderer.render(i) for i in model.items) + result += indent(items_str, ' ') + result += '\n}' + return result + + +@DefaultDBMLRenderer.renderer_for(EnumItem) +def render_enum_item(model: EnumItem) -> str: + result = comment_to_dbml(model.comment) if model.comment else '' + result += f'"{model.name}"' + if model.note: + result += f' [{note_option_to_dbml(model.note)}]' + return result diff --git a/pydbml/renderer/dbml/default/expression.py b/pydbml/renderer/dbml/default/expression.py new file mode 100644 index 0000000..627f286 --- /dev/null +++ b/pydbml/renderer/dbml/default/expression.py @@ -0,0 +1,7 @@ +from pydbml.classes import Expression +from pydbml.renderer.dbml.default.renderer import DefaultDBMLRenderer + + +@DefaultDBMLRenderer.renderer_for(Expression) +def render_expression(model: Expression) -> str: + return f'`{model.text}`' diff --git a/pydbml/renderer/dbml/default/index.py b/pydbml/renderer/dbml/default/index.py new file mode 100644 index 0000000..74d6081 --- /dev/null +++ b/pydbml/renderer/dbml/default/index.py @@ -0,0 +1,49 @@ +from typing import List, Any + +from pydbml.classes import Index, Expression, Column +from pydbml.renderer.dbml.default.renderer import DefaultDBMLRenderer +from pydbml.renderer.dbml.default.utils import comment_to_dbml, note_option_to_dbml + + +def render_subjects(source_subjects: List[Any]) -> str: + subjects = [] + + for subj in source_subjects: + if isinstance(subj, Column): + subjects.append(subj.name) + elif isinstance(subj, Expression): + subjects.append(DefaultDBMLRenderer.render(subj)) + else: + subjects.append(subj) + + if len(subjects) > 1: + return f'({", ".join(subj for subj in subjects)})' + else: + return subjects[0] + + +def render_options(model: Index) -> str: + options = [] + if model.name: + options.append(f"name: '{model.name}'") + if model.pk: + options.append('pk') + if model.unique: + options.append('unique') + if model.type: + options.append(f'type: {model.type}') + if model.note: + options.append(note_option_to_dbml(model.note)) + + if options: + return f' [{", ".join(options)}]' + return '' + + +@DefaultDBMLRenderer.renderer_for(Index) +def render_index(model: Index) -> str: + return ( + (comment_to_dbml(model.comment) if model.comment else '') + + render_subjects(model.subjects) + + render_options(model) + ) diff --git a/pydbml/renderer/dbml/default/note.py b/pydbml/renderer/dbml/default/note.py new file mode 100644 index 0000000..0bff77f --- /dev/null +++ b/pydbml/renderer/dbml/default/note.py @@ -0,0 +1,14 @@ +from textwrap import indent + +from pydbml.classes import Note +from pydbml.renderer.dbml.default.renderer import DefaultDBMLRenderer +from pydbml.renderer.dbml.default.utils import quote_string + + +@DefaultDBMLRenderer.renderer_for(Note) +def render_note(model: Note) -> str: + text = quote_string(model.text) + + text = indent(text, ' ') + result = f'Note {{\n{text}\n}}' + return result diff --git a/pydbml/renderer/dbml/default/project.py b/pydbml/renderer/dbml/default/project.py new file mode 100644 index 0000000..f885b10 --- /dev/null +++ b/pydbml/renderer/dbml/default/project.py @@ -0,0 +1,29 @@ +from textwrap import indent +from typing import Dict + +from pydbml.classes import Project +from pydbml.renderer.dbml.default.renderer import DefaultDBMLRenderer +from pydbml.renderer.dbml.default.utils import comment_to_dbml +from pydbml.tools import doublequote_string + + +def render_items(items: Dict[str, str]) -> str: + items_str = '' + for k, v in items.items(): + if '\n' in v: + items_str += f"{k}: '''{v}'''\n" + else: + items_str += f"{k}: '{v}'\n" + return indent(items_str.rstrip('\n'), ' ') + '\n' + + +@DefaultDBMLRenderer.renderer_for(Project) +def render_project(model: Project) -> str: + result = comment_to_dbml(model.comment) if model.comment else '' + quoted_name = doublequote_string(model.name) + result += f'Project {quoted_name} {{\n' + result += render_items(model.items) + if model.note: + result += indent(DefaultDBMLRenderer.render(model.note), ' ') + '\n' + result += '}' + return result diff --git a/pydbml/renderer/dbml/default/reference.py b/pydbml/renderer/dbml/default/reference.py new file mode 100644 index 0000000..abf8874 --- /dev/null +++ b/pydbml/renderer/dbml/default/reference.py @@ -0,0 +1,68 @@ +from itertools import chain +from textwrap import indent +from typing import List + +from pydbml.classes import Reference, Column +from pydbml.exceptions import TableNotFoundError, DBMLError +from pydbml.renderer.dbml.default.renderer import DefaultDBMLRenderer +from pydbml.renderer.dbml.default.utils import comment_to_dbml +from .table import get_full_name_for_dbml + + +def validate_for_dbml(model: Reference): + for col in chain(model.col1, model.col2): + if col.table is None: + raise TableNotFoundError(f'Table on {col} is not set') + + +def render_inline_reference(model: Reference) -> str: + # settings are ignored for inline ref + if len(model.col2) > 1: + raise DBMLError('Cannot render DBML: composite ref cannot be inline') + table_name = get_full_name_for_dbml(model.col2[0].table) + return f'ref: {model.type} {table_name}."{model.col2[0].name}"' + + +def render_col(col: List[Column]) -> str: + if len(col) == 1: + return f'"{col[0].name}"' + else: + names = (f'"{c.name}"' for c in col) + return f'({", ".join(names)})' + + +def render_options(model: Reference) -> str: + options = [] + if model.on_update: + options.append(f'update: {model.on_update}') + if model.on_delete: + options.append(f'delete: {model.on_delete}') + if options: + return f' [{", ".join(options)}]' + return '' + + +def render_not_inline_reference(model: Reference) -> str: + result = comment_to_dbml(model.comment) if model.comment else '' + result += 'Ref' + if model.name: + result += f' {model.name}' + + result += ( + ' {\n ' # type: ignore + f'{get_full_name_for_dbml(model.table1)}.{render_col(model.col1)} ' + f'{model.type} ' + f'{get_full_name_for_dbml(model.table2)}.{render_col(model.col2)}' + f'{render_options(model)}' + '\n}' + ) + return result + + +@DefaultDBMLRenderer.renderer_for(Reference) +def render_reference(model: Reference) -> str: + validate_for_dbml(model) + if model.inline: + return render_inline_reference(model) + else: + return render_not_inline_reference(model) diff --git a/pydbml/renderer/dbml/default/renderer.py b/pydbml/renderer/dbml/default/renderer.py new file mode 100644 index 0000000..0445c75 --- /dev/null +++ b/pydbml/renderer/dbml/default/renderer.py @@ -0,0 +1,19 @@ +from typing import TYPE_CHECKING, List + +from pydbml.renderer.base import BaseRenderer +from pydbml._classes.base import DBMLObject + +if TYPE_CHECKING: # pragma: no cover + from pydbml.database import Database + + +class DefaultDBMLRenderer(BaseRenderer): + model_renderers = {} + + @classmethod + def render_db(cls, db: 'Database') -> str: + items: List[DBMLObject] = [db.project] if db.project else [] + refs = (ref for ref in db.refs if not ref.inline) + items.extend((*db.enums, *db.tables, *refs, *db.table_groups, *db.sticky_notes)) + + return '\n\n'.join(cls.render(i) for i in items) diff --git a/pydbml/renderer/dbml/default/sticky_note.py b/pydbml/renderer/dbml/default/sticky_note.py new file mode 100644 index 0000000..36d1122 --- /dev/null +++ b/pydbml/renderer/dbml/default/sticky_note.py @@ -0,0 +1,15 @@ +from textwrap import indent + + +from pydbml.classes import StickyNote +from pydbml.renderer.dbml.default.renderer import DefaultDBMLRenderer +from pydbml.renderer.dbml.default.utils import quote_string + + +@DefaultDBMLRenderer.renderer_for(StickyNote) +def render_sticky_note(model: StickyNote) -> str: + text = quote_string(model.text) + + text = indent(text, ' ') + result = f'Note {model.name} {{\n{text}\n}}' + return result diff --git a/pydbml/renderer/dbml/default/table.py b/pydbml/renderer/dbml/default/table.py new file mode 100644 index 0000000..c896066 --- /dev/null +++ b/pydbml/renderer/dbml/default/table.py @@ -0,0 +1,58 @@ +import re +from textwrap import indent + +from pydbml.classes import Table +from pydbml.renderer.dbml.default.renderer import DefaultDBMLRenderer +from pydbml.renderer.dbml.default.utils import comment_to_dbml, quote_string + + +def get_full_name_for_dbml(model) -> str: + if model.schema == 'public': + return f'"{model.name}"' + else: + return f'"{model.schema}"."{model.name}"' + + +def render_header(model: Table) -> str: + name = get_full_name_for_dbml(model) + + result = f'Table {name} ' + if model.alias: + result += f'as "{model.alias}" ' + if model.header_color: + result += f'[headercolor: {model.header_color}] ' + return result + + +def render_indexes(model: Table) -> str: + if model.indexes: + result = '\n indexes {\n' + indexes_str = '\n'.join(DefaultDBMLRenderer.render(i) for i in model.indexes) + result += indent(indexes_str, ' ') + '\n' + result += ' }\n' + return result + return '' + + +@DefaultDBMLRenderer.renderer_for(Table) +def render_table(model: Table) -> str: + result = comment_to_dbml(model.comment) if model.comment else '' + result += render_header(model) + + result += '{\n' + columns_str = '\n'.join(DefaultDBMLRenderer.render(c) for c in model.columns) + result += indent(columns_str, ' ') + '\n' + + if model.properties: + if model.database and model.database.allow_properties: + properties_str = '\n' + '\n'.join(f'{key}: {quote_string(value)}' for key, value in model.properties.items()) + '\n' + properties_str = indent(properties_str, ' ') + result += properties_str + + if model.note: + result += indent(model.note.dbml, ' ') + '\n' + + result += render_indexes(model) + + result += '}' + return result diff --git a/pydbml/renderer/dbml/default/table_group.py b/pydbml/renderer/dbml/default/table_group.py new file mode 100644 index 0000000..ca7cb44 --- /dev/null +++ b/pydbml/renderer/dbml/default/table_group.py @@ -0,0 +1,23 @@ +from textwrap import indent + +from pydbml.classes import TableGroup +from pydbml.renderer.dbml.default.renderer import DefaultDBMLRenderer +from pydbml.renderer.dbml.default.table import get_full_name_for_dbml +from pydbml.renderer.dbml.default.utils import comment_to_dbml +from pydbml.tools import doublequote_string + + +@DefaultDBMLRenderer.renderer_for(TableGroup) +def render_table_group(model: TableGroup) -> str: + result = comment_to_dbml(model.comment) if model.comment else '' + quoted_name = doublequote_string(model.name) + result += f'TableGroup {quoted_name}' + if model.color: + result += f' [color: {model.color}]' + result += ' {\n' + for i in model.items: + result += f' {get_full_name_for_dbml(i)}\n' + if model.note: + result += indent(model.note.dbml, ' ') + '\n' + result += '}' + return result diff --git a/pydbml/renderer/dbml/default/utils.py b/pydbml/renderer/dbml/default/utils.py new file mode 100644 index 0000000..3e71c47 --- /dev/null +++ b/pydbml/renderer/dbml/default/utils.py @@ -0,0 +1,31 @@ +import re +from typing import TYPE_CHECKING + +from pydbml.tools import comment + +if TYPE_CHECKING: # pragma: no cover + from pydbml.classes import Note + + +def prepare_text_for_dbml(text: str) -> str: + '''Escape single quotes''' + pattern = re.compile(r"('''|')") + return pattern.sub(r'\\\1', text) + + +def quote_string(text: str) -> str: + if '\n' in text: + return f"'''\n{prepare_text_for_dbml(text)}'''" + else: + return f"'{prepare_text_for_dbml(text)}'" + + +def note_option_to_dbml(note: 'Note') -> str: + if '\n' in note.text: + return f"note: '''{prepare_text_for_dbml(note.text)}'''" + else: + return f"note: '{prepare_text_for_dbml(note.text)}'" + + +def comment_to_dbml(val: str) -> str: + return comment(val, '//') diff --git a/pydbml/renderer/sql/__init__.py b/pydbml/renderer/sql/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pydbml/renderer/sql/default/__init__.py b/pydbml/renderer/sql/default/__init__.py new file mode 100644 index 0000000..c3b1e48 --- /dev/null +++ b/pydbml/renderer/sql/default/__init__.py @@ -0,0 +1,8 @@ +from .renderer import DefaultSQLRenderer +from .column import render_column +from .enum import render_enum, render_enum_item +from .expression import render_expression +from .index import render_index +from .note import render_note +from .reference import render_reference +from .table import render_table diff --git a/pydbml/renderer/sql/default/column.py b/pydbml/renderer/sql/default/column.py new file mode 100644 index 0000000..d04b061 --- /dev/null +++ b/pydbml/renderer/sql/default/column.py @@ -0,0 +1,39 @@ +from pydbml.classes import Column, Enum, Expression +from pydbml.renderer.sql.default.renderer import DefaultSQLRenderer +from .utils import comment_to_sql +from .enum import get_full_name_for_sql as get_full_name_for_sql_enum + + +@DefaultSQLRenderer.renderer_for(Column) +def render_column(model: Column) -> str: + ''' + Returns inline SQL of the column, which should be a part of table definition: + + "id" integer PRIMARY KEY AUTOINCREMENT + ''' + + components = [f'"{model.name}"'] + if isinstance(model.type, Enum): + components.append(get_full_name_for_sql_enum(model.type)) + else: + components.append(str(model.type)) + + table_has_composite_pk = model.table._has_composite_pk() if model.table else False + if model.pk and not table_has_composite_pk: # composite PKs are rendered in table sql + components.append('PRIMARY KEY') + if model.autoinc: + components.append('AUTOINCREMENT') + if model.unique: + components.append('UNIQUE') + if model.not_null: + components.append('NOT NULL') + if model.default is not None: + if isinstance(model.default, Expression): + default = DefaultSQLRenderer.render(model.default) + else: + default = model.default # type: ignore + components.append(f'DEFAULT {default}') + + result = comment_to_sql(model.comment) if model.comment else '' + result += ' '.join(components) + return result diff --git a/pydbml/renderer/sql/default/enum.py b/pydbml/renderer/sql/default/enum.py new file mode 100644 index 0000000..5d1b3ba --- /dev/null +++ b/pydbml/renderer/sql/default/enum.py @@ -0,0 +1,34 @@ +from textwrap import indent + +from pydbml._classes.enum import EnumItem +from pydbml.classes import Enum +from pydbml.renderer.sql.default.renderer import DefaultSQLRenderer +from pydbml.renderer.sql.default.utils import comment_to_sql, get_full_name_for_sql + + +@DefaultSQLRenderer.renderer_for(Enum) +def render_enum(model: Enum) -> str: + ''' + Returns SQL for enum type: + + CREATE TYPE "job_status" AS ENUM ( + 'created', + 'running', + 'done', + 'failure', + ); + ''' + + result = comment_to_sql(model.comment) if model.comment else '' + result += f'CREATE TYPE {get_full_name_for_sql(model)} AS ENUM (\n' + enum_body = '\n'.join(f'{indent(DefaultSQLRenderer.render(i), " ")}' for i in model.items) + result += enum_body.rstrip(',') + result += '\n);' + return result + + +@DefaultSQLRenderer.renderer_for(EnumItem) +def render_enum_item(model: EnumItem) -> str: + result = comment_to_sql(model.comment) if model.comment else '' + result += f"'{model.name}'," + return result diff --git a/pydbml/renderer/sql/default/expression.py b/pydbml/renderer/sql/default/expression.py new file mode 100644 index 0000000..b080ee7 --- /dev/null +++ b/pydbml/renderer/sql/default/expression.py @@ -0,0 +1,7 @@ +from pydbml.classes import Expression +from pydbml.renderer.sql.default.renderer import DefaultSQLRenderer + + +@DefaultSQLRenderer.renderer_for(Expression) +def render_expression(model: Expression) -> str: + return f'({model.text})' diff --git a/pydbml/renderer/sql/default/index.py b/pydbml/renderer/sql/default/index.py new file mode 100644 index 0000000..3202496 --- /dev/null +++ b/pydbml/renderer/sql/default/index.py @@ -0,0 +1,65 @@ +from typing import Any + +from pydbml.classes import Expression, Index, Column +from pydbml.renderer.sql.default.renderer import DefaultSQLRenderer +from pydbml.renderer.sql.default.utils import comment_to_sql + + +def render_subject(subject: Any) -> str: + if isinstance(subject, Column): + return f'"{subject.name}"' + elif isinstance(subject, Expression): + return DefaultSQLRenderer.render(subject) + else: + return subject + + +def render_pk(model: Index, keys: str) -> str: + result = comment_to_sql(model.comment) if model.comment else '' + result += f'PRIMARY KEY ({keys})' + return result + + +def create_components(model: Index, keys: str) -> str: + components = [] + if model.comment: + components.append(comment_to_sql(model.comment)) + + components.append('CREATE ') + + if model.unique: + components.append('UNIQUE ') + + components.append('INDEX ') + + if model.name: + components.append(f'"{model.name}" ') + if model.table: + components.append(f'ON "{model.table.name}" ') + + if model.type: + components.append(f'USING {model.type.upper()} ') + components.append(f'({keys})') + return ''.join(components) + ';' + + +@DefaultSQLRenderer.renderer_for(Index) +def render_index(model: Index) -> str: + ''' + Returns inline SQL of the index to be created separately from table + definition: + + CREATE UNIQUE INDEX ON "products" USING HASH ("id"); + + But if it's a (composite) primary key index, returns an inline SQL for + composite primary key to be used inside table definition: + + PRIMARY KEY ("id", "name") + ''' + + keys = ', '.join(render_subject(s) for s in model.subjects) + + if model.pk: + return render_pk(model, keys) + + return create_components(model, keys) diff --git a/pydbml/renderer/sql/default/note.py b/pydbml/renderer/sql/default/note.py new file mode 100644 index 0000000..751bfd5 --- /dev/null +++ b/pydbml/renderer/sql/default/note.py @@ -0,0 +1,43 @@ +import re + +from pydbml.classes import Note, Table, Column +from pydbml.renderer.sql.default.renderer import DefaultSQLRenderer + + +def prepare_text_for_sql(model: Note) -> str: + ''' + - Process special escape sequence: slash before line break, which means no line break + https://www.dbml.org/docs/#multi-line-string + - replace all single quotes with double quotes + ''' + + pattern = re.compile(r'\\\n') + result = pattern.sub('', model.text) + + result = result.replace("'", '"') + return result + + +def generate_comment_on(model: Note, entity: str, name: str) -> str: + """Generate a COMMENT ON clause out from this note.""" + quoted_text = f"'{prepare_text_for_sql(model)}'" + note_sql = f'COMMENT ON {entity.upper()} "{name}" IS {quoted_text};' + return note_sql + + +@DefaultSQLRenderer.renderer_for(Note) +def render_note(model: Note) -> str: + """ + For Tables and Columns Note is converted into COMMENT ON clause. All other entities don't + have notes generated in their SQL code, but as a fallback their notes are rendered as SQL + comments when sql property is called directly. + """ + + if model.text: + if isinstance(model.parent, (Table, Column)): + return generate_comment_on(model, model.parent.__class__.__name__, model.parent.name) + else: + text = prepare_text_for_sql(model) + return '\n'.join(f'-- {line}' for line in text.split('\n')) + else: + return '' diff --git a/pydbml/renderer/sql/default/reference.py b/pydbml/renderer/sql/default/reference.py new file mode 100644 index 0000000..ee83f11 --- /dev/null +++ b/pydbml/renderer/sql/default/reference.py @@ -0,0 +1,82 @@ +from itertools import chain +from typing import List + +from pydbml.classes import Reference, Column +from pydbml.constants import MANY_TO_MANY, MANY_TO_ONE, ONE_TO_ONE, ONE_TO_MANY +from pydbml.exceptions import TableNotFoundError +from pydbml.renderer.sql.default.renderer import DefaultSQLRenderer +from pydbml.renderer.sql.default.utils import comment_to_sql, get_full_name_for_sql + + +def col_names(cols: List[Column]) -> str: + return ', '.join(f'"{c.name}"' for c in cols) + + +def validate_for_sql(model: Reference): + for col in chain(model.col1, model.col2): + if col.table is None: + raise TableNotFoundError(f'Table on {col} is not set') + + +def generate_inline_sql(model: Reference, source_col: List[Column], ref_col: List[Column]) -> str: + result = comment_to_sql(model.comment) if model.comment else '' + result += ( + f'{{c}}FOREIGN KEY ({col_names(source_col)}) ' # type: ignore + f'REFERENCES {get_full_name_for_sql(ref_col[0].table)} ({col_names(ref_col)})' # type: ignore + ) + if model.on_update: + result += f' ON UPDATE {model.on_update.upper()}' + if model.on_delete: + result += f' ON DELETE {model.on_delete.upper()}' + return result + + +def generate_not_inline_sql(model: Reference, source_col: List['Column'], ref_col: List['Column']): + result = comment_to_sql(model.comment) if model.comment else '' + result += ( + f'ALTER TABLE {get_full_name_for_sql(source_col[0].table)}' # type: ignore + f' ADD {{c}}FOREIGN KEY ({col_names(source_col)})' + f' REFERENCES {get_full_name_for_sql(ref_col[0].table)} ({col_names(ref_col)})' # type: ignore + ) + if model.on_update: + result += f' ON UPDATE {model.on_update.upper()}' + if model.on_delete: + result += f' ON DELETE {model.on_delete.upper()}' + return result + ';' + + +def generate_many_to_many_sql(model: Reference) -> str: + join_table = model.join_table + table_sql = join_table.sql # type: ignore + + n = len(model.col1) + ref1_sql = generate_not_inline_sql(model, join_table.columns[:n], model.col1) # type: ignore + ref2_sql = generate_not_inline_sql(model, join_table.columns[n:], model.col2) # type: ignore + + result = '\n\n'.join((table_sql, ref1_sql, ref2_sql)) + return result.format(c='') + + +@DefaultSQLRenderer.renderer_for(Reference) +def render_reference(model: Reference) -> str: + ''' + Returns SQL of the reference: + + ALTER TABLE "orders" ADD FOREIGN KEY ("customer_id") REFERENCES "customers ("id"); + + ''' + validate_for_sql(model) + + if model.type == MANY_TO_MANY: + return generate_many_to_many_sql(model) + + result = '' + func = generate_inline_sql if model.inline else generate_not_inline_sql + if model.type in (MANY_TO_ONE, ONE_TO_ONE): + result = func(model=model, source_col=model.col1, ref_col=model.col2) + elif model.type == ONE_TO_MANY: + result = func(model=model, source_col=model.col2, ref_col=model.col1) + + c = f'CONSTRAINT "{model.name}" ' if model.name else '' + + return result.format(c=c) diff --git a/pydbml/renderer/sql/default/renderer.py b/pydbml/renderer/sql/default/renderer.py new file mode 100644 index 0000000..188d07d --- /dev/null +++ b/pydbml/renderer/sql/default/renderer.py @@ -0,0 +1,24 @@ +from typing import TYPE_CHECKING + +from pydbml.renderer.sql.default.utils import reorder_tables_for_sql +from pydbml.renderer.base import BaseRenderer + + +if TYPE_CHECKING: # pragma: no cover + from pydbml.database import Database + + +class DefaultSQLRenderer(BaseRenderer): + model_renderers = {} + + @classmethod + def render(cls, model) -> str: + model.check_attributes_for_sql() + return super().render(model) + + @classmethod + def render_db(cls, db: 'Database') -> str: + refs = (ref for ref in db.refs if not ref.inline) + tables = reorder_tables_for_sql(db.tables, db.refs) + components = (cls.render(i) for i in (*db.enums, *tables, *refs)) + return '\n\n'.join(components) diff --git a/pydbml/renderer/sql/default/table.py b/pydbml/renderer/sql/default/table.py new file mode 100644 index 0000000..d208143 --- /dev/null +++ b/pydbml/renderer/sql/default/table.py @@ -0,0 +1,97 @@ +from textwrap import indent +from typing import List + +from pydbml.constants import MANY_TO_ONE, ONE_TO_ONE, ONE_TO_MANY +from pydbml.classes import Table, Reference, Column +from pydbml.exceptions import UnknownDatabaseError +from pydbml.renderer.sql.default.note import prepare_text_for_sql +from pydbml.renderer.sql.default.renderer import DefaultSQLRenderer +from pydbml.renderer.sql.default.utils import comment_to_sql, get_full_name_for_sql + + +def get_references_for_sql(model: Table) -> List[Reference]: + """ + Return all references in the database where this table is on the left side of SQL + reference definition. + """ + if not model.database: + raise UnknownDatabaseError(f'Database for the table {model} is not set') + result = [] + for ref in model.database.refs: + if (ref.type in (MANY_TO_ONE, ONE_TO_ONE)) and\ + (ref.table1 == model): + result.append(ref) + elif (ref.type == ONE_TO_MANY) and (ref.table2 == model): + result.append(ref) + return result + + +def get_inline_references_for_sql(model: Table) -> List[Reference]: + ''' + Return inline references for this table sql definition + ''' + if model.abstract: + return [] + return [r for r in get_references_for_sql(model) if r.inline] + + +def create_body(model: Table) -> str: + body: List[str] = [] + body.extend(indent(DefaultSQLRenderer.render(c), " ") for c in model.columns) + body.extend(indent(DefaultSQLRenderer.render(i), " ") for i in model.indexes if i.pk) + body.extend(indent(DefaultSQLRenderer.render(r), " ") for r in get_inline_references_for_sql(model)) + + if model._has_composite_pk(): + body.append( + " PRIMARY KEY (" + + ', '.join(f'"{c.name}"' for c in model.columns if c.pk) + + ')') + + return ',\n'.join(body) + + +def create_components(model: Table) -> str: + components = [comment_to_sql(model.comment)] if model.comment else [] + components.append(f'CREATE TABLE {get_full_name_for_sql(model)} (') + + body = create_body(model) + + components.append(body) + components.append(');') + components.extend('\n' + DefaultSQLRenderer.render(i) for i in model.indexes if not i.pk) + + return '\n'.join(components) + + +def render_column_notes(model: Table) -> str: + result = '' + for col in model.columns: + if col.note: + quoted_note = f"'{prepare_text_for_sql(col.note)}'" + note_sql = f'COMMENT ON COLUMN "{model.name}"."{col.name}" IS {quoted_note};' + result += f'\n\n{note_sql}' + return result + + +@DefaultSQLRenderer.renderer_for(Table) +def render_table(model: Table) -> str: + ''' + Returns full SQL for table definition: + + CREATE TABLE "countries" ( + "code" int PRIMARY KEY, + "name" varchar, + "continent_name" varchar + ); + + Also returns indexes if they were defined: + + CREATE INDEX ON "products" ("id", "name"); + ''' + result = create_components(model) + + if model.note: + result += f'\n\n{model.note.sql}' + + result += render_column_notes(model) + return result diff --git a/pydbml/renderer/sql/default/utils.py b/pydbml/renderer/sql/default/utils.py new file mode 100644 index 0000000..8befac3 --- /dev/null +++ b/pydbml/renderer/sql/default/utils.py @@ -0,0 +1,37 @@ +from typing import List, Dict, Union + +from pydbml.classes import Enum, Reference, Table +from pydbml.constants import MANY_TO_ONE, ONE_TO_MANY +from pydbml.tools import comment + + +def comment_to_sql(val: str) -> str: + return comment(val, '--') + + +def reorder_tables_for_sql(tables: List['Table'], refs: List['Reference']) -> List['Table']: + """ + Attempt to reorder the tables, so that they are defined in SQL before they are referenced by + inline foreign keys. + + Won't aid the rare cases of cross-references and many-to-many relations. + """ + + references: Dict[str, int] = {} + for ref in refs: + if ref.inline: + if ref.type == MANY_TO_ONE and ref.table1 is not None: + table_name = ref.table1.name + elif ref.type == ONE_TO_MANY and ref.table2 is not None: + table_name = ref.table2.name + else: # pragma: no cover + continue + references[table_name] = references.get(table_name, 0) + 1 + return sorted(tables, key=lambda t: references.get(t.name, 0), reverse=True) + + +def get_full_name_for_sql(model: Union[Table, Enum]) -> str: + if model.schema == 'public': + return f'"{model.name}"' + else: + return f'"{model.schema}"."{model.name}"' diff --git a/pydbml/tools.py b/pydbml/tools.py new file mode 100644 index 0000000..353136a --- /dev/null +++ b/pydbml/tools.py @@ -0,0 +1,54 @@ +import re +from typing import TYPE_CHECKING + +if TYPE_CHECKING: # pragma: no cover + pass + + +def comment(val: str, comb: str) -> str: + return '\n'.join(f'{comb} {cl}' for cl in val.split('\n')) + '\n' + + +def indent(val: str, spaces=4) -> str: + if val == '': + return val + return ' ' * spaces + val.replace('\n', '\n' + ' ' * spaces) + + +def remove_bom(source: str) -> str: + if source and source[0] == '\ufeff': + source = source[1:] + return source + + +def strip_empty_lines(source: str) -> str: + """Remove empty lines or lines with just spaces from beginning and end.""" + pattern = re.compile(r'^([ \t]*\n)*(?P[\s\S]+?)(\n[ \t]*)*$') + return pattern.sub(r'\g', source) + + +def doublequote_string(source: str) -> str: + """Safely wrap a single-line string in double quotes""" + if '\n' in source: + raise ValueError(f'Multiline strings are not allowed: {source!r}') + result = source.strip('"').replace('"', '\\"') + return f'"{result}"' + + +def remove_indentation(source: str) -> str: + if not source: + return source + + pattern = re.compile(r'^\s*') + + lines = source.split('\n') + spaces = [] + for line in lines: + if line and not line.isspace(): + indent_match = pattern.search(line) + if indent_match is not None: # this is just for you mypy + spaces.append(len(indent_match[0])) + + indent = min(spaces) + lines = [l[indent:] for l in lines] + return '\n'.join(lines) diff --git a/setup.py b/setup.py index b5b9989..fc5cf27 100644 --- a/setup.py +++ b/setup.py @@ -1,7 +1,6 @@ -from setuptools import setup +from setuptools import setup, find_packages - -SHORT_DESCRIPTION = 'DBML syntax parser for Python' +SHORT_DESCRIPTION = 'Python parser and builder for DBML' try: with open('README.md', encoding='utf8') as readme: @@ -13,22 +12,20 @@ setup( name='pydbml', - python_requires='>=3.5', + python_requires='>=3.8', description=SHORT_DESCRIPTION, long_description=LONG_DESCRIPTION, long_description_content_type='text/markdown', - version='0.4.0', + version='1.2.0', author='Daniil Minukhin', author_email='ddddsa@gmail.com', url='https://github.com/Vanderhoof/PyDBML', - packages=['pydbml', 'pydbml.definitions'], + packages=find_packages(exclude=['test', 'test.*']), license='MIT', platforms='any', - install_requires=[ - 'pyparsing>=2.4.7', - ], + install_requires=['pyparsing>=3.0.0'], classifiers=[ - "Development Status :: 4 - Beta", + "Development Status :: 5 - Production/Stable", "Environment :: Console", "Intended Audience :: Developers", "License :: OSI Approved :: MIT License", diff --git a/test.sh b/test.sh index 8fd1cdf..8e64632 100755 --- a/test.sh +++ b/test.sh @@ -1,3 +1,2 @@ -python3 -m doctest README.md &&\ - python3 -m unittest discover &&\ - mypy . --ignore-missing-imports +pytest --doctest-glob="*.md" &&\ + mypy pydbml --ignore-missing-imports diff --git a/test/conftest.py b/test/conftest.py new file mode 100644 index 0000000..8ef74aa --- /dev/null +++ b/test/conftest.py @@ -0,0 +1,138 @@ +from textwrap import dedent + +import pytest + +from pydbml import Database +from pydbml._classes.reference import Reference +from pydbml._classes.sticky_note import StickyNote +from pydbml.classes import Column, Enum, EnumItem, Note, Expression, Table, Index + + +@pytest.fixture +def db(): + return Database() + + +@pytest.fixture +def enum_item1(): + return EnumItem('en-US') + + +@pytest.fixture +def enum1(): + return Enum('product status', ('production', 'development')) + + +@pytest.fixture +def expression1() -> Expression: + return Expression('SUM(amount)') + + +@pytest.fixture +def simple_column() -> Column: + return Column( + name='id', + type='integer' + ) + + +@pytest.fixture +def simple_column_with_table(db: Database, table1: Table, simple_column: Column) -> Column: + table1.add_column(simple_column) + db.add(table1) + return simple_column + + +@pytest.fixture +def complex_column(enum1: Enum) -> Column: + return Column( + name='counter', + type=enum1, + pk=True, + autoinc=True, + unique=True, + not_null=True, + default=0, + comment='This is a counter column', + note=Note('This is a note for the column'), + properties={'foo': 'bar', 'baz': "qux\nqux"} + ) + + +@pytest.fixture +def complex_column_with_table(db: Database, table1: Table, complex_column: Column) -> Column: + table1.add_column(complex_column) + db.add(table1) + return complex_column + + +@pytest.fixture +def table1() -> Table: + return Table( + name='products', + columns=[ + Column('id', 'integer'), + Column('name', 'varchar'), + ] + ) + + +@pytest.fixture +def table2() -> Table: + return Table( + name='products', + columns=[ + Column('id', 'integer'), + Column('name', 'varchar'), + ] + ) + + +@pytest.fixture +def table3() -> Table: + return Table( + name='orders', + columns=[ + Column('id', 'integer'), + Column('product_id', 'integer'), + Column('price', 'float'), + ] + ) + +@pytest.fixture +def reference1(table2: Table, table3: Table) -> Reference: + return Reference( + type='>', + col1=[table3.columns[1]], + col2=[table2.columns[0]], + ) + + +@pytest.fixture +def index1(table1: Table) -> Index: + result = Index( + subjects=[table1.columns[1]] + ) + table1.add_index(result) + return result + + +@pytest.fixture +def note1(): + return Note('Simple note') + + +@pytest.fixture +def sticky_note1(): + return StickyNote(name='mynote', text='Simple note') + + +@pytest.fixture +def multiline_note(): + return Note( + dedent( + '''\ + This is a multiline note. + It has multiple lines.''' + ) + ) diff --git a/test/test_blueprints/__init__.py b/test/test_blueprints/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/test_blueprints/test_column.py b/test/test_blueprints/test_column.py new file mode 100644 index 0000000..ab9a7cf --- /dev/null +++ b/test/test_blueprints/test_column.py @@ -0,0 +1,90 @@ +from unittest import TestCase +from unittest.mock import Mock + +from pydbml.classes import Column +from pydbml.classes import Enum +from pydbml.classes import EnumItem +from pydbml.classes import Note +from pydbml.database import Database +from pydbml.parser.blueprints import ColumnBlueprint +from pydbml.parser.blueprints import NoteBlueprint + + +class TestColumn(TestCase): + def test_build_minimal(self) -> None: + bp = ColumnBlueprint( + name='testcol', + type='varchar' + ) + result = bp.build() + self.assertIsInstance(result, Column) + self.assertEqual(result.name, bp.name) + self.assertEqual(result.type, bp.type) + + def test_build_full(self) -> None: + bp = ColumnBlueprint( + name='id', + type='number', + unique=True, + not_null=True, + pk=True, + autoinc=True, + default=0, + note=NoteBlueprint(text='note text'), + comment='Col commment' + ) + result = bp.build() + self.assertIsInstance(result, Column) + self.assertEqual(result.name, bp.name) + self.assertEqual(result.type, bp.type) + self.assertEqual(result.unique, bp.unique) + self.assertEqual(result.not_null, bp.not_null) + self.assertEqual(result.pk, bp.pk) + self.assertEqual(result.autoinc, bp.autoinc) + self.assertEqual(result.default, bp.default) + self.assertIsInstance(result.note, Note) + self.assertEqual(result.note.text, bp.note.text) + self.assertEqual(result.comment, bp.comment) + + def test_enum_type(self) -> None: + s = Database() + e = Enum( + 'myenum', + items=[ + EnumItem('i1'), + EnumItem('i2') + ] + ) + s.add(e) + parser = Mock() + parser.database = s + + bp = ColumnBlueprint( + name='testcol', + type='myenum' + ) + bp.parser = parser + result = bp.build() + self.assertIs(result.type, e) + + def test_enum_type_schema(self) -> None: + s = Database() + e = Enum( + 'myenum', + schema='myschema', + items=[ + EnumItem('i1'), + EnumItem('i2') + ] + ) + s.add(e) + parser = Mock() + parser.database = s + + bp = ColumnBlueprint( + name='testcol', + type='myschema.myenum' + ) + bp.parser = parser + result = bp.build() + self.assertIs(result.type, e) diff --git a/test/test_blueprints/test_enum.py b/test/test_blueprints/test_enum.py new file mode 100644 index 0000000..63a9fcf --- /dev/null +++ b/test/test_blueprints/test_enum.py @@ -0,0 +1,50 @@ +from unittest import TestCase + +from pydbml.classes import Enum +from pydbml.classes import EnumItem +from pydbml.classes import Note +from pydbml.parser.blueprints import EnumBlueprint +from pydbml.parser.blueprints import EnumItemBlueprint +from pydbml.parser.blueprints import NoteBlueprint + + +class TestEnumItemBlueprint(TestCase): + def test_build_minimal(self) -> None: + bp = EnumItemBlueprint( + name='Red' + ) + result = bp.build() + self.assertIsInstance(result, EnumItem) + self.assertEqual(result.name, bp.name) + + def test_build_full(self) -> None: + bp = EnumItemBlueprint( + name='Red', + note=NoteBlueprint(text='Note text'), + comment='Comment text' + ) + result = bp.build() + self.assertIsInstance(result, EnumItem) + self.assertEqual(result.name, bp.name) + self.assertIsInstance(result.note, Note) + self.assertEqual(result.note.text, bp.note.text) + self.assertEqual(result.comment, bp.comment) + + +class TestEnumBlueprint(TestCase): + def test_build(self) -> None: + bp = EnumBlueprint( + name='Colors', + items=[ + EnumItemBlueprint(name='Red'), + EnumItemBlueprint(name='Green'), + EnumItemBlueprint(name='Blue') + ], + comment='Comment text' + ) + result = bp.build() + self.assertIsInstance(result, Enum) + self.assertEqual(result.name, bp.name) + self.assertEqual(result.comment, bp.comment) + for ei in result.items: + self.assertIsInstance(ei, EnumItem) diff --git a/test/test_blueprints/test_expression.py b/test/test_blueprints/test_expression.py new file mode 100644 index 0000000..18ac8dc --- /dev/null +++ b/test/test_blueprints/test_expression.py @@ -0,0 +1,12 @@ +from unittest import TestCase + +from pydbml.classes import Expression +from pydbml.parser.blueprints import ExpressionBlueprint + + +class TestNote(TestCase): + def test_build(self) -> None: + bp = ExpressionBlueprint(text='amount*2') + result = bp.build() + self.assertIsInstance(result, Expression) + self.assertEqual(result.text, bp.text) diff --git a/test/test_blueprints/test_index.py b/test/test_blueprints/test_index.py new file mode 100644 index 0000000..dc17440 --- /dev/null +++ b/test/test_blueprints/test_index.py @@ -0,0 +1,37 @@ +from unittest import TestCase + +from pydbml.classes import Index +from pydbml.classes import Note +from pydbml.parser.blueprints import IndexBlueprint +from pydbml.parser.blueprints import NoteBlueprint + + +class TestIndex(TestCase): + def test_build_minimal(self) -> None: + bp = IndexBlueprint( + subject_names=['a', 'b', 'c'] + ) + result = bp.build() + self.assertIsInstance(result, Index) + self.assertEqual(result.subject_names, []) + + def test_build_full(self) -> None: + bp = IndexBlueprint( + subject_names=['a', 'b', 'c'], + name='MyIndex', + unique=True, + type='hash', + pk=True, + note=NoteBlueprint(text='Note text'), + comment='Comment text' + ) + result = bp.build() + self.assertIsInstance(result, Index) + self.assertEqual(result.subject_names, []) + self.assertEqual(result.name, bp.name) + self.assertEqual(result.unique, bp.unique) + self.assertEqual(result.type, bp.type) + self.assertEqual(result.pk, bp.pk) + self.assertIsInstance(result.note, Note) + self.assertEqual(result.note.text, bp.note.text) + self.assertEqual(result.comment, bp.comment) diff --git a/test/test_blueprints/test_note.py b/test/test_blueprints/test_note.py new file mode 100644 index 0000000..85028fa --- /dev/null +++ b/test/test_blueprints/test_note.py @@ -0,0 +1,46 @@ +from unittest import TestCase + +from pydbml.classes import Note +from pydbml.parser.blueprints import NoteBlueprint + + +class TestNote(TestCase): + def test_build(self) -> None: + bp = NoteBlueprint(text='Note text') + result = bp.build() + self.assertIsInstance(result, Note) + self.assertEqual(result.text, bp.text) + + def test_preformat_not_needed(self): + oneline = 'One line of note text' + multiline = 'Multiline\nnote\n\ntext' + long_line = 'Lorem ipsum dolor sit amet consectetur adipisicing elit. Aspernatur quidem adipisci, impedit, ut illum dolorum consequatur odio voluptate numquam ea itaque excepturi, a libero placeat corrupti. Amet beatae suscipit necessitatibus. Ea expedita explicabo iste quae rem aliquam minus cumque eveniet enim delectus, alias aut impedit quaerat quia ex, aliquid sint amet iusto rerum! Sunt deserunt ea saepe corrupti officiis. Assumenda.' + + bp = NoteBlueprint(text=oneline) + self.assertEqual(bp._preformat_text(), oneline) + bp = NoteBlueprint(text=multiline) + self.assertEqual(bp._preformat_text(), multiline) + bp = NoteBlueprint(text=long_line) + self.assertEqual(bp._preformat_text(), long_line) + + def test_preformat_needed(self): + uniform_indentation = ' line1\n line2\n line3' + varied_indentation = ' line1\n line2\n\n line3' + empty_lines = '\n\n\n\n\n\n\nline1\nline2\nline3\n\n\n\n\n\n\n' + empty_indented_lines = '\n \n\n \n\n line1\n line2\n line3\n\n\n\n \n\n\n' + + exptected = 'line1\nline2\nline3' + bp = NoteBlueprint(text=uniform_indentation) + self.assertEqual(bp._preformat_text(), exptected) + + exptected = 'line1\n line2\n\n line3' + bp = NoteBlueprint(text=varied_indentation) + self.assertEqual(bp._preformat_text(), exptected) + + exptected = 'line1\nline2\nline3' + bp = NoteBlueprint(text=empty_lines) + self.assertEqual(bp._preformat_text(), exptected) + + exptected = 'line1\nline2\nline3' + bp = NoteBlueprint(text=empty_indented_lines) + self.assertEqual(bp._preformat_text(), exptected) diff --git a/test/test_blueprints/test_project.py b/test/test_blueprints/test_project.py new file mode 100644 index 0000000..eebff86 --- /dev/null +++ b/test/test_blueprints/test_project.py @@ -0,0 +1,36 @@ +from unittest import TestCase + +from pydbml.classes import Note +from pydbml.classes import Project +from pydbml.parser.blueprints import NoteBlueprint +from pydbml.parser.blueprints import ProjectBlueprint + + +class TestProjectBlueprint(TestCase): + def test_build_minimal(self) -> None: + bp = ProjectBlueprint( + name='MyProject' + ) + result = bp.build() + self.assertIsInstance(result, Project) + self.assertEqual(result.name, bp.name) + + def test_build_full(self) -> None: + bp = ProjectBlueprint( + name='MyProject', + items={ + 'author': 'John Wick', + 'nickname': 'Baba Yaga', + 'reason': 'revenge' + }, + note=NoteBlueprint(text='note text'), + comment='comment text' + ) + result = bp.build() + self.assertIsInstance(result, Project) + self.assertEqual(result.name, bp.name) + self.assertEqual(result.items, bp.items) + self.assertIsNot(result.items, bp.items) + self.assertIsInstance(result.note, Note) + self.assertEqual(result.note.text, bp.note.text) + self.assertEqual(result.comment, bp.comment) diff --git a/test/test_blueprints/test_reference.py b/test/test_blueprints/test_reference.py new file mode 100644 index 0000000..dd9bfd7 --- /dev/null +++ b/test/test_blueprints/test_reference.py @@ -0,0 +1,85 @@ +from unittest import TestCase +from unittest.mock import Mock + +from pydbml.classes import Column +from pydbml.classes import Reference +from pydbml.classes import Table +from pydbml.exceptions import ColumnNotFoundError +from pydbml.exceptions import TableNotFoundError +from pydbml.parser.blueprints import ReferenceBlueprint + + +class TestReferenceBlueprint(TestCase): + def test_build_minimal(self) -> None: + bp = ReferenceBlueprint( + type='>', + inline=True, + table1='table1', + col1='col1', + table2='table2', + col2='col2', + ) + + t1 = Table( + name='table1' + ) + c1 = Column(name='col1', type='Number') + t1.add_column(c1) + t2 = Table( + name='table2' + ) + c2 = Column(name='col2', type='Varchar') + t2.add_column(c2) + + with self.assertRaises(RuntimeError): + bp.build() + + parserMock = Mock() + parserMock.locate_table.side_effect = [t1, t2] + bp.parser = parserMock + result = bp.build() + self.assertIsInstance(result, Reference) + self.assertEqual(result.type, bp.type) + self.assertEqual(result.inline, bp.inline) + self.assertEqual(parserMock.locate_table.call_count, 2) + self.assertEqual(result.col1, [c1]) + self.assertEqual(result.col2, [c2]) + + def test_tables_and_cols_are_not_set(self) -> None: + bp = ReferenceBlueprint( + type='>', + inline=True, + table1=None, + col1='col1', + table2='table2', + col2='col2' + ) + with self.assertRaises(TableNotFoundError): + bp.build() + + bp.table1 = 'table1' + bp.table2 = None + with self.assertRaises(TableNotFoundError): + bp.build() + + bp.table2 = 'table2' + bp.col1 = None + with self.assertRaises(ColumnNotFoundError): + bp.build() + + bp.col1 = 'col1' + bp.col2 = None + with self.assertRaises(ColumnNotFoundError): + bp.build() + + def test_tables_and_cols_are_set(self) -> None: + bp = ReferenceBlueprint( + type='>', + inline=True, + table1='table1', + col1='col1', + table2='table2', + col2=None + ) + with self.assertRaises(ColumnNotFoundError): + bp.build() diff --git a/test/test_blueprints/test_sticky_note.py b/test/test_blueprints/test_sticky_note.py new file mode 100644 index 0000000..1f423a5 --- /dev/null +++ b/test/test_blueprints/test_sticky_note.py @@ -0,0 +1,53 @@ +from unittest import TestCase + +from pydbml._classes.sticky_note import StickyNote +from pydbml.parser.blueprints import StickyNoteBlueprint + +class TestNote(TestCase): + def test_build(self) -> None: + bp = StickyNoteBlueprint(name='mynote', text='Note text') + result = bp.build() + self.assertIsInstance(result, StickyNote) + self.assertEqual(result.name, bp.name) + self.assertEqual(result.text, bp.text) + + def test_preformat_not_needed(self): + oneline = 'One line of note text' + multiline = 'Multiline\nnote\n\ntext' + long_line = 'Lorem ipsum dolor sit amet consectetur adipisicing elit. Aspernatur quidem adipisci, impedit, ut illum dolorum consequatur odio voluptate numquam ea itaque excepturi, a libero placeat corrupti. Amet beatae suscipit necessitatibus. Ea expedita explicabo iste quae rem aliquam minus cumque eveniet enim delectus, alias aut impedit quaerat quia ex, aliquid sint amet iusto rerum! Sunt deserunt ea saepe corrupti officiis. Assumenda.' + + bp = StickyNoteBlueprint(name='mynote', text=oneline) + self.assertEqual(bp.name, bp.name) + self.assertEqual(bp._preformat_text(), oneline) + bp = StickyNoteBlueprint(name='mynote', text=multiline) + self.assertEqual(bp.name, bp.name) + self.assertEqual(bp._preformat_text(), multiline) + bp = StickyNoteBlueprint(name='mynote', text=long_line) + self.assertEqual(bp.name, bp.name) + self.assertEqual(bp._preformat_text(), long_line) + + def test_preformat_needed(self): + uniform_indentation = ' line1\n line2\n line3' + varied_indentation = ' line1\n line2\n\n line3' + empty_lines = '\n\n\n\n\n\n\nline1\nline2\nline3\n\n\n\n\n\n\n' + empty_indented_lines = '\n \n\n \n\n line1\n line2\n line3\n\n\n\n \n\n\n' + + exptected = 'line1\nline2\nline3' + bp = StickyNoteBlueprint(name='mynote', text=uniform_indentation) + self.assertEqual(bp._preformat_text(), exptected) + self.assertEqual(bp.name, bp.name) + + exptected = 'line1\n line2\n\n line3' + bp = StickyNoteBlueprint(name='mynote', text=varied_indentation) + self.assertEqual(bp._preformat_text(), exptected) + self.assertEqual(bp.name, bp.name) + + exptected = 'line1\nline2\nline3' + bp = StickyNoteBlueprint(name='mynote', text=empty_lines) + self.assertEqual(bp._preformat_text(), exptected) + self.assertEqual(bp.name, bp.name) + + exptected = 'line1\nline2\nline3' + bp = StickyNoteBlueprint(name='mynote', text=empty_indented_lines) + self.assertEqual(bp._preformat_text(), exptected) + self.assertEqual(bp.name, bp.name) diff --git a/test/test_blueprints/test_table.py b/test/test_blueprints/test_table.py new file mode 100644 index 0000000..b12ce42 --- /dev/null +++ b/test/test_blueprints/test_table.py @@ -0,0 +1,128 @@ +from unittest import TestCase + +from pydbml.classes import Column +from pydbml.classes import Expression +from pydbml.classes import Index +from pydbml.classes import Note +from pydbml.classes import Table +from pydbml.exceptions import ColumnNotFoundError +from pydbml.parser.blueprints import ColumnBlueprint +from pydbml.parser.blueprints import ExpressionBlueprint +from pydbml.parser.blueprints import IndexBlueprint +from pydbml.parser.blueprints import NoteBlueprint +from pydbml.parser.blueprints import ReferenceBlueprint +from pydbml.parser.blueprints import TableBlueprint + + +class TestTable(TestCase): + def test_build_minimal(self) -> None: + bp = TableBlueprint(name='TestTable') + result = bp.build() + self.assertIsInstance(result, Table) + self.assertEqual(result.name, bp.name) + + def test_build_full_simple(self) -> None: + bp = TableBlueprint( + name='TestTable', + alias='TestAlias', + note=NoteBlueprint(text='Note text'), + header_color='#ccc', + comment='comment text' + ) + result = bp.build() + self.assertIsInstance(result, Table) + self.assertEqual(result.name, bp.name) + self.assertEqual(result.alias, bp.alias) + self.assertIsInstance(result.note, Note) + self.assertEqual(result.note.text, bp.note.text) + self.assertEqual(result.header_color, bp.header_color) + self.assertEqual(result.comment, bp.comment) + + def test_with_columns(self) -> None: + bp = TableBlueprint( + name='TestTable', + columns=[ + ColumnBlueprint(name='id', type='Integer', not_null=True, autoinc=True), + ColumnBlueprint(name='name', type='Varchar') + ] + ) + result = bp.build() + self.assertIsInstance(result, Table) + self.assertEqual(result.name, bp.name) + for col in result.columns: + self.assertIsInstance(col, Column) + + def test_with_indexes(self) -> None: + bp = TableBlueprint( + name='TestTable', + columns=[ + ColumnBlueprint(name='id', type='Integer', not_null=True, autoinc=True), + ColumnBlueprint(name='name', type='Varchar') + ], + indexes=[ + IndexBlueprint(subject_names=['name', 'id'], unique=True), + IndexBlueprint(subject_names=['id', ExpressionBlueprint('name*2')], name='ExprIndex') + ] + ) + result = bp.build() + self.assertIsInstance(result, Table) + self.assertEqual(result.name, bp.name) + for col in result.columns: + self.assertIsInstance(col, Column) + for ind in result.indexes: + self.assertIsInstance(ind, Index) + self.assertIsInstance(result.indexes[1].subjects[1], Expression) + + def test_bad_index(self) -> None: + bp = TableBlueprint( + name='TestTable', + columns=[ + ColumnBlueprint(name='id', type='Integer', not_null=True, autoinc=True), + ColumnBlueprint(name='name', type='Varchar') + ], + indexes=[ + IndexBlueprint(subject_names=['name', 'id'], unique=True), + IndexBlueprint(subject_names=['wrong', '(name*2)'], name='ExprIndex') + ] + ) + with self.assertRaises(ColumnNotFoundError): + bp.build() + + def test_get_reference_blueprints(self) -> None: + bp = TableBlueprint( + name='TestTable', + columns=[ + ColumnBlueprint( + name='id', + type='Integer', + not_null=True, + autoinc=True, + ref_blueprints=[ + ReferenceBlueprint( + type='<', + inline=True, + table2='AnotherTable', + col2=['AnotherCol']) + ] + ), + ColumnBlueprint( + name='name', + type='Varchar', + ref_blueprints=[ + ReferenceBlueprint( + type='>', + inline=True, + table2='YetAnotherTable', + col2=['YetAnotherCol']) + ] + ) + ] + ) + result = bp.build() + self.assertIsInstance(result, Table) + self.assertEqual(result.name, bp.name) + ref_bps = bp.get_reference_blueprints() + self.assertEqual(ref_bps[0].table1, result.name) + self.assertEqual(ref_bps[0].col1, 'id') + self.assertEqual(ref_bps[1].table1, result.name) + self.assertEqual(ref_bps[1].col1, 'name') diff --git a/test/test_blueprints/test_table_group.py b/test/test_blueprints/test_table_group.py new file mode 100644 index 0000000..763ebf8 --- /dev/null +++ b/test/test_blueprints/test_table_group.py @@ -0,0 +1,71 @@ +from unittest import TestCase +from unittest.mock import Mock + +from pydbml.classes import Table +from pydbml.classes import TableGroup +from pydbml.exceptions import ValidationError +from pydbml.parser.blueprints import TableGroupBlueprint + + +class TestTableGroupBlueprint(TestCase): + def test_build(self) -> None: + bp = TableGroupBlueprint( + name='TestTableGroup', + items=['table1', 'table2'], + comment='Comment text' + ) + with self.assertRaises(RuntimeError): + bp.build() + + parserMock = Mock() + parserMock.locate_table.side_effect = [ + Table(name='table1'), + Table(name='table2') + ] + bp.parser = parserMock + result = bp.build() + self.assertIsInstance(result, TableGroup) + self.assertEqual(parserMock.locate_table.call_count, 2) + for i in result.items: + self.assertIsInstance(i, Table) + + def test_build_with_schema(self) -> None: + bp = TableGroupBlueprint( + name='TestTableGroup', + items=['myschema.table1', 'myschema.table2'], + comment='Comment text' + ) + with self.assertRaises(RuntimeError): + bp.build() + + parserMock = Mock() + parserMock.locate_table.side_effect = [ + Table(name='table1', schema='myschema'), + Table(name='table2', schema='myschema') + ] + bp.parser = parserMock + result = bp.build() + self.assertIsInstance(result, TableGroup) + locate_table_calls = parserMock.locate_table.call_args_list + self.assertEqual(len(locate_table_calls), 2) + self.assertEqual(locate_table_calls[0].args, ('myschema', 'table1')) + self.assertEqual(locate_table_calls[1].args, ('myschema', 'table2')) + for i in result.items: + self.assertIsInstance(i, Table) + + def test_duplicate_table(self) -> None: + bp = TableGroupBlueprint( + name='TestTableGroup', + items=['table1', 'table2', 'table1'], + comment='Comment text' + ) + + parserMock = Mock() + parserMock.locate_table.side_effect = [ + Table(name='table1'), + Table(name='table2'), + Table(name='table1') + ] + bp.parser = parserMock + with self.assertRaises(ValidationError): + bp.build() diff --git a/test/test_classes.py b/test/test_classes.py deleted file mode 100644 index 43e9fb0..0000000 --- a/test/test_classes.py +++ /dev/null @@ -1,385 +0,0 @@ -from unittest import TestCase - -from pydbml.classes import Column -from pydbml.classes import Enum -from pydbml.classes import EnumItem -from pydbml.classes import Index -from pydbml.classes import Note -from pydbml.classes import ReferenceBlueprint -from pydbml.classes import SQLOjbect -from pydbml.classes import Table -from pydbml.classes import TableReference -from pydbml.exceptions import AttributeMissingError -from pydbml.exceptions import ColumnNotFoundError -from pydbml.exceptions import DuplicateReferenceError - - -class TestDBMLObject(TestCase): - def test_check_attributes_for_sql(self) -> None: - o = SQLOjbect() - o.a1 = None - o.b1 = None - o.c1 = None - o.required_attributes = ('a1', 'b1') - with self.assertRaises(AttributeMissingError): - o.check_attributes_for_sql() - o.a1 = 1 - with self.assertRaises(AttributeMissingError): - o.check_attributes_for_sql() - o.b1 = 'a2' - o.check_attributes_for_sql() - - def test_comparison(self) -> None: - o1 = SQLOjbect() - o1.a1 = None - o1.b1 = 'c' - o1.c1 = 123 - o2 = SQLOjbect() - o2.a1 = None - o2.b1 = 'c' - o2.c1 = 123 - self.assertTrue(o1 == o2) - o1.a2 = True - self.assertFalse(o1 == o2) - - -# class TestReferenceBlueprint(TestCase): -# def test_basic_sql(self) -> None: -# r = ReferenceBlueprint( -# ReferenceBlueprint.MANY_TO_ONE, -# table1='bookings', -# col1='country', -# table2='ids', -# col2='id' -# ) -# expected = 'ALTER TABLE "bookings" ADD FOREIGN KEY ("country") REFERENCES "ids ("id");' -# self.assertEqual(r.sql, expected) -# r.type = ReferenceBlueprint.ONE_TO_ONE -# self.assertEqual(r.sql, expected) -# r.type = ReferenceBlueprint.ONE_TO_MANY -# expected2 = 'ALTER TABLE "ids" ADD FOREIGN KEY ("id") REFERENCES "bookings ("country");' -# self.assertEqual(r.sql, expected2) - -# def test_full(self) -> None: -# r = ReferenceBlueprint( -# ReferenceBlueprint.MANY_TO_ONE, -# name='refname', -# table1='bookings', -# col1='country', -# table2='ids', -# col2='id', -# on_update='cascade', -# on_delete='restrict' -# ) -# expected = 'ALTER TABLE "bookings" ADD CONSTRAINT "refname" FOREIGN KEY ("country") REFERENCES "ids ("id") ON UPDATE CASCADE ON DELETE RESTRICT;' -# self.assertEqual(r.sql, expected) -# r.type = ReferenceBlueprint.ONE_TO_ONE -# self.assertEqual(r.sql, expected) -# r.type = ReferenceBlueprint.ONE_TO_MANY -# expected2 = 'ALTER TABLE "ids" ADD CONSTRAINT "refname" FOREIGN KEY ("id") REFERENCES "bookings ("country") ON UPDATE CASCADE ON DELETE RESTRICT;' -# self.assertEqual(r.sql, expected2) - - -# class TestTableReference(TestCase): -# def test_basic_sql(self) -> None: -# r = TableReference(col='order_id', -# ref_table='orders', -# ref_col='id') -# expected = 'FOREIGN KEY ("order_id") REFERENCES "orders ("id")' -# self.assertEqual(r.sql, expected) - -# def test_full(self) -> None: -# r = TableReference(col='order_id', -# ref_table='orders', -# ref_col='id', -# name='refname', -# on_delete='set null', -# on_update='no action') -# expected = 'CONSTRAINT "refname" FOREIGN KEY ("order_id") REFERENCES "orders ("id") ON UPDATE NO ACTION ON DELETE SET NULL' -# self.assertEqual(r.sql, expected) - - -class TestColumn(TestCase): - def test_basic_sql(self) -> None: - r = Column(name='id', - type_='integer') - expected = '"id" integer' - self.assertEqual(r.sql, expected) - - def test_note(self) -> None: - n = Note('Column note') - r = Column(name='id', - type_='integer', - note=n) - expected = '"id" integer -- Column note' - self.assertEqual(r.sql, expected) - - def test_pk_autoinc(self) -> None: - r = Column(name='id', - type_='integer', - pk=True, - autoinc=True) - expected = '"id" integer PRIMARY KEY AUTOINCREMENT' - self.assertEqual(r.sql, expected) - - def test_unique_not_null(self) -> None: - r = Column(name='id', - type_='integer', - unique=True, - not_null=True) - expected = '"id" integer UNIQUE NOT NULL' - self.assertEqual(r.sql, expected) - - def test_default(self) -> None: - r = Column(name='order', - type_='integer', - default=0) - expected = '"order" integer DEFAULT 0' - self.assertEqual(r.sql, expected) - - def test_table_setter(self) -> None: - r1 = ReferenceBlueprint( - ReferenceBlueprint.MANY_TO_ONE, - name='refname', - table1='bookings', - col1='order_id', - table2='orders', - col2='order', - ) - r2 = ReferenceBlueprint( - ReferenceBlueprint.MANY_TO_ONE, - name='refname', - table1='purchases', - col1='order_id', - table2='orders', - col2='order', - ) - c = Column( - name='order', - type_='integer', - default=0, - ref_blueprints=[r1, r2] - ) - t = Table('orders') - c.table = t - self.assertEqual(c.table, t) - self.assertEqual(c.ref_blueprints[0].table1, t.name) - self.assertEqual(c.ref_blueprints[1].table1, t.name) - - -class TestIndex(TestCase): - def test_basic_sql(self) -> None: - t = Table('products') - t.add_column(Column('id', 'integer')) - r = Index(subject_names=['id'], - table=t) - t.add_index(r) - expected = 'CREATE INDEX ON "products" ("id");' - self.assertEqual(r.sql, expected) - - def test_note(self) -> None: - t = Table('products') - t.add_column(Column('id', 'integer')) - n = Note('Index note') - r = Index(subject_names=['id'], - table=t, - note=n) - t.add_index(r) - expected = 'CREATE INDEX ON "products" ("id"); -- Index note' - self.assertEqual(r.sql, expected) - - def test_unique_type_composite(self) -> None: - t = Table('products') - t.add_column(Column('id', 'integer')) - t.add_column(Column('name', 'varchar')) - r = Index(subject_names=['id', 'name'], - table=t, - type_='hash', - unique=True) - t.add_index(r) - expected = 'CREATE UNIQUE INDEX ON "products" USING HASH ("id", "name");' - self.assertEqual(r.sql, expected) - - def test_pk(self) -> None: - t = Table('products') - t.add_column(Column('id', 'integer')) - t.add_column(Column('name', 'varchar')) - r = Index(subject_names=['id', 'name'], - table=t, - pk=True) - t.add_index(r) - expected = 'PRIMARY KEY ("id", "name")' - self.assertEqual(r.sql, expected) - - def test_composite_with_expression(self) -> None: - t = Table('products') - t.add_column(Column('id', 'integer')) - r = Index(subject_names=['id', '(id*3)'], - table=t) - t.add_index(r) - self.assertEqual(r.subjects, [t['id'], '(id*3)']) - expected = 'CREATE INDEX ON "products" ("id", (id*3));' - self.assertEqual(r.sql, expected) - - -class TestTable(TestCase): - def test_one_column(self) -> None: - t = Table('products') - c = Column('id', 'integer') - t.add_column(c) - expected = 'CREATE TABLE "products" (\n "id" integer\n);\n' - self.assertEqual(t.sql, expected) - - def test_ref(self) -> None: - t = Table('products') - c1 = Column('id', 'integer') - c2 = Column('name', 'varchar2') - t.add_column(c1) - t.add_column(c2) - t2 = Table('names') - c21 = Column('name_val', 'varchar2') - t2.add_column(c21) - r = TableReference(c2, t2, c21) - t.add_ref(r) - expected = \ -'''CREATE TABLE "products" ( - "id" integer, - "name" varchar2, - FOREIGN KEY ("name") REFERENCES "names ("name_val") -); -''' - self.assertEqual(t.sql, expected) - - def test_duplicate_ref(self) -> None: - t = Table('products') - c1 = Column('id', 'integer') - c2 = Column('name', 'varchar2') - t.add_column(c1) - t.add_column(c2) - t2 = Table('names') - c21 = Column('name_val', 'varchar2') - t2.add_column(c21) - r1 = TableReference(c2, t2, c21) - t.add_ref(r1) - r2 = TableReference(c2, t2, c21) - self.assertEqual(r1, r2) - with self.assertRaises(DuplicateReferenceError): - t.add_ref(r2) - - def test_note(self) -> None: - n = Note('Table note') - t = Table('products', note=n) - c = Column('id', 'integer') - t.add_column(c) - expected = 'CREATE TABLE "products" (\n -- Table note\n "id" integer\n);\n' - self.assertEqual(t.sql, expected) - - def test_ref_index(self) -> None: - t = Table('products') - c1 = Column('id', 'integer') - c2 = Column('name', 'varchar2') - t.add_column(c1) - t.add_column(c2) - t2 = Table('names') - c21 = Column('name_val', 'varchar2') - t2.add_column(c21) - r = TableReference(c2, t2, c21) - t.add_ref(r) - i = Index(['id', 'name']) - t.add_index(i) - expected = \ -'''CREATE TABLE "products" ( - "id" integer, - "name" varchar2, - FOREIGN KEY ("name") REFERENCES "names ("name_val") -); - -CREATE INDEX ON "products" ("id", "name"); -''' - self.assertEqual(t.sql, expected) - - def test_index_inline(self) -> None: - t = Table('products') - c1 = Column('id', 'integer') - c2 = Column('name', 'varchar2') - t.add_column(c1) - t.add_column(c2) - i = Index(['id', 'name'], pk=True) - t.add_index(i) - expected = \ -'''CREATE TABLE "products" ( - "id" integer, - "name" varchar2, - PRIMARY KEY ("id", "name") -); -''' - self.assertEqual(t.sql, expected) - - def test_add_column(self) -> None: - t = Table('products') - c1 = Column('id', 'integer') - c2 = Column('name', 'varchar2') - t.add_column(c1) - t.add_column(c2) - self.assertEqual(c1.table, t) - self.assertEqual(c2.table, t) - self.assertEqual(t.columns, [c1, c2]) - - def test_add_index(self) -> None: - t = Table('products') - c1 = Column('id', 'integer') - c2 = Column('name', 'varchar2') - i1 = Index(['id']) - i2 = Index(['name']) - t.add_column(c1) - t.add_column(c2) - t.add_index(i1) - t.add_index(i2) - self.assertEqual(i1.table, t) - self.assertEqual(i2.table, t) - self.assertEqual(t.indexes, [i1, i2]) - - def test_add_bad_index(self) -> None: - t = Table('products') - c = Column('id', 'integer') - i = Index(['id', 'name']) - t.add_column(c) - with self.assertRaises(ColumnNotFoundError): - t.add_index(i) - - -class TestEnum(TestCase): - def test_simple_enum(self) -> None: - items = [ - EnumItem('created'), - EnumItem('running'), - EnumItem('donef'), - EnumItem('failure'), - ] - e = Enum('job_status', items) - expected = \ -'''CREATE TYPE "job_status" AS ENUM ( - 'created', - 'running', - 'donef', - 'failure', -);''' - self.assertEqual(e.sql, expected) - - def test_notes(self) -> None: - n = Note('EnumItem note') - items = [ - EnumItem('created', note=n), - EnumItem('running'), - EnumItem('donef', note=n), - EnumItem('failure'), - ] - e = Enum('job_status', items) - expected = \ -'''CREATE TYPE "job_status" AS ENUM ( - 'created', -- EnumItem note - 'running', - 'donef', -- EnumItem note - 'failure', -);''' - self.assertEqual(e.sql, expected) diff --git a/test/test_classes/__init__.py b/test/test_classes/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/test_classes/test_base.py b/test/test_classes/test_base.py new file mode 100644 index 0000000..ebfae0e --- /dev/null +++ b/test/test_classes/test_base.py @@ -0,0 +1,34 @@ +from unittest import TestCase + +from pydbml._classes.base import SQLObject +from pydbml.exceptions import AttributeMissingError + + +class TestDBMLObject(TestCase): + def test_check_attributes_for_sql(self) -> None: + o = SQLObject() + o.a1 = None + o.b1 = None + o.c1 = None + o.required_attributes = ('a1', 'b1') + with self.assertRaises(AttributeMissingError): + o.check_attributes_for_sql() + o.a1 = 1 + with self.assertRaises(AttributeMissingError): + o.check_attributes_for_sql() + o.b1 = 'a2' + o.check_attributes_for_sql() + + def test_comparison(self) -> None: + o1 = SQLObject() + o1.a1 = None + o1.b1 = 'c' + o1.c1 = 123 + o2 = SQLObject() + o2.a1 = None + o2.b1 = 'c' + o2.c1 = 123 + self.assertTrue(o1 == o2) + o1.a2 = True + self.assertFalse(o1 == o2) + self.assertFalse(o1 == 123) diff --git a/test/test_classes/test_column.py b/test/test_classes/test_column.py new file mode 100644 index 0000000..dca0cd8 --- /dev/null +++ b/test/test_classes/test_column.py @@ -0,0 +1,166 @@ +from unittest import TestCase + +from pydbml.classes import Column +from pydbml.classes import Note +from pydbml.classes import Reference +from pydbml.classes import Table +from pydbml.database import Database +from pydbml.exceptions import TableNotFoundError + + +class TestColumn(TestCase): + def test_attributes(self) -> None: + name = 'name' + type = 'type' + unique = True + not_null = True + pk = True + autoinc = True + default = '1' + note = Note('note') + comment = 'comment' + col = Column( + name=name, + type=type, + unique=unique, + not_null=not_null, + pk=pk, + autoinc=autoinc, + default=default, + note=note, + comment=comment, + ) + self.assertEqual(col.name, name) + self.assertEqual(col.type, type) + self.assertEqual(col.unique, unique) + self.assertEqual(col.not_null, not_null) + self.assertEqual(col.pk, pk) + self.assertEqual(col.autoinc, autoinc) + self.assertEqual(col.default, default) + self.assertEqual(col.note.text, note.text) + self.assertEqual(col.comment, comment) + + def test_database_set(self) -> None: + col = Column('name', 'int') + table = Table('name') + self.assertIsNone(col.database) + table.add_column(col) + self.assertIsNone(col.database) + database = Database() + database.add(table) + self.assertIs(col.database, database) + + def test_pk_autoinc(self) -> None: + r = Column(name='id', + type='integer', + pk=True, + autoinc=True) + expected = '"id" integer PRIMARY KEY AUTOINCREMENT' + self.assertEqual(r.sql, expected) + + def test_unique_not_null(self) -> None: + r = Column(name='id', + type='integer', + unique=True, + not_null=True) + expected = '"id" integer UNIQUE NOT NULL' + self.assertEqual(r.sql, expected) + + def test_default(self) -> None: + r = Column(name='order', + type='integer', + default=0) + expected = '"order" integer DEFAULT 0' + self.assertEqual(r.sql, expected) + + def test_comment(self) -> None: + r = Column(name='id', + type='integer', + unique=True, + not_null=True, + comment="Column comment") + expected = \ +'''-- Column comment +"id" integer UNIQUE NOT NULL''' + self.assertEqual(r.sql, expected) + + def test_database(self): + c1 = Column(name='client_id', type='integer') + t1 = Table(name='products') + + self.assertIsNone(c1.database) + t1.add_column(c1) + self.assertIsNone(c1.database) + s = Database() + s.add(t1) + self.assertIs(c1.database, s) + + def test_get_refs(self) -> None: + c1 = Column(name='client_id', type='integer') + with self.assertRaises(TableNotFoundError): + c1.get_refs() + t1 = Table(name='products') + t1.add_column(c1) + c2 = Column(name='id', type='integer', autoinc=True, pk=True) + t2 = Table(name='clients') + t2.add_column(c2) + + ref = Reference(type='>', col1=c1, col2=c2, inline=True) + s = Database() + s.add(t1) + s.add(t2) + s.add(ref) + + self.assertEqual(c1.get_refs(), [ref]) + + def test_note_property(self): + note1 = Note('column note') + c1 = Column(name='client_id', type='integer') + c1.note = note1 + self.assertIs(c1.note.parent, c1) + + +class TestEqual: + @staticmethod + def test_other_type() -> None: + c1 = Column('name', 'VARCHAR2') + assert c1 != 'name' + + @staticmethod + def test_different_tables() -> None: + t1 = Table('table1', columns=[Column('name', 'VARCHAR2')]) + t2 = Table('table2', columns=[Column('name', 'VARCHAR2')]) + assert t1.columns[0] != t2.columns[0] + + @staticmethod + def test_same_table() -> None: + t1 = Table('table1', columns=[Column('name', 'VARCHAR2')]) + t2 = Table('table1', columns=[Column('name', 'VARCHAR2')]) + assert t1.columns[0] == t2.columns[0] + + @staticmethod + def test_same_column() -> None: + c1 = Column('name', 'VARCHAR2') + assert c1 == c1 + + @staticmethod + def test_table_not_set() -> None: + c1 = Column('name', 'VARCHAR2') + c2 = Column('name', 'VARCHAR2') + assert c1 == c2 + + @staticmethod + def test_ont_table_not_set() -> None: + c1 = Column('name', 'VARCHAR2') + c2 = Column('name', 'VARCHAR2') + t1 = Table('table1') + c1.table = t1 + assert c1 != c2 + + c1.table, c2.table = None, t1 + assert c1 != c2 + + +def test_repr() -> None: + c1 = Column('name', 'VARCHAR2') + assert repr(c1) == "" \ No newline at end of file diff --git a/test/test_classes/test_enum.py b/test/test_classes/test_enum.py new file mode 100644 index 0000000..fb053fa --- /dev/null +++ b/test/test_classes/test_enum.py @@ -0,0 +1,54 @@ +from pydbml.classes import Enum +from pydbml.classes import EnumItem +from pydbml.classes import Note +from unittest import TestCase + + +class TestEnumItem(TestCase): + def test_note_property(self): + note1 = Note('enum item note') + ei = EnumItem('en-US', note='preferred', comment='EnumItem comment') + ei.note = note1 + self.assertIs(ei.note.parent, ei) + + +class TestEnum(TestCase): + def test_getitem(self) -> None: + ei = EnumItem('created') + items = [ + EnumItem('running'), + ei, + EnumItem('donef'), + EnumItem('failure'), + ] + e = Enum('job_status', items) + self.assertIs(e[1], ei) + with self.assertRaises(IndexError): + e[22] + with self.assertRaises(TypeError): + e['abc'] + + def test_iter(self) -> None: + ei1 = EnumItem('created') + ei2 = EnumItem('running') + ei3 = EnumItem('donef') + ei4 = EnumItem('failure') + items = [ + ei1, + ei2, + ei3, + ei4, + ] + e = Enum('job_status', items) + + for i1, i2 in zip(e, [ei1, ei2, ei3, ei4]): + self.assertIs(i1, i2) + + +def test_repr(enum_item1: EnumItem) -> None: + assert repr(enum_item1) == "" + + +def test_str() -> None: + ei = EnumItem('en-US') + assert str(ei) == 'en-US' diff --git a/test/test_classes/test_expression.py b/test/test_classes/test_expression.py new file mode 100644 index 0000000..171ac9e --- /dev/null +++ b/test/test_classes/test_expression.py @@ -0,0 +1,11 @@ +from unittest import TestCase + +from pydbml.classes import Expression + + +def test_str(expression1: Expression) -> None: + assert str(expression1) == 'SUM(amount)' + + +def test_repr(expression1: Expression) -> None: + assert repr(expression1) == "Expression('SUM(amount)')" diff --git a/test/test_classes/test_index.py b/test/test_classes/test_index.py new file mode 100644 index 0000000..25db20d --- /dev/null +++ b/test/test_classes/test_index.py @@ -0,0 +1,21 @@ +from pydbml.classes import Column +from pydbml.classes import Index +from pydbml.classes import Note +from pydbml.classes import Table + + +def test_note_property(): + note1 = Note('column note') + t = Table('products') + c = Column('id', 'integer') + i = Index(subjects=[c]) + i.note = note1 + assert i.note.parent is i + + +def test_repr(index1: Index) -> None: + assert repr(index1) == "" + + +def test_str(index1: Index) -> None: + assert str(index1) == 'Index(products[name])' diff --git a/test/test_classes/test_note.py b/test/test_classes/test_note.py new file mode 100644 index 0000000..00b700a --- /dev/null +++ b/test/test_classes/test_note.py @@ -0,0 +1,23 @@ +from pydbml.classes import Note + + +def test_init_types(): + n1 = Note('My note text') + n2 = Note(3) + n3 = Note([1, 2, 3]) + n4 = Note(None) + n5 = Note(n1) + + assert n1.text == 'My note text' + assert n2.text == '3' + assert n3.text == '[1, 2, 3]' + assert n4.text == '' + assert n5.text == 'My note text' + + +def test_str(note1: Note) -> None: + assert str(note1) == 'Simple note' + + +def test_repr(note1: Note) -> None: + assert repr(note1) == "Note('Simple note')" diff --git a/test/test_classes/test_project.py b/test/test_classes/test_project.py new file mode 100644 index 0000000..3442b3a --- /dev/null +++ b/test/test_classes/test_project.py @@ -0,0 +1,15 @@ +from pydbml.classes import Note +from pydbml.classes import Project + + +def test_note_property(): + note1 = Note('column note') + p = Project('myproject') + p.note = note1 + assert p.note.parent is p + + +def test_repr() -> None: + project = Project('myproject') + assert repr(project) == "" + diff --git a/test/test_classes/test_reference.py b/test/test_classes/test_reference.py new file mode 100644 index 0000000..5e7cc14 --- /dev/null +++ b/test/test_classes/test_reference.py @@ -0,0 +1,130 @@ +from unittest import TestCase + +from pydbml.classes import Column +from pydbml.classes import Reference +from pydbml.classes import Table +from pydbml.exceptions import DBMLError +from pydbml.exceptions import TableNotFoundError +from pydbml.renderer.sql.default.reference import validate_for_sql + + +class TestReference(TestCase): + def test_table1(self): + t = Table('products') + c1 = Column('name', 'varchar2') + t2 = Table('names') + c2 = Column('name_val', 'varchar2') + t2.add_column(c2) + ref = Reference('>', c1, c2) + self.assertIsNone(ref.table1) + t.add_column(c1) + self.assertIs(ref.table1, t) + + def test_join_table(self) -> None: + t1 = Table('books') + c11 = Column('id', 'integer', pk=True) + c12 = Column('author', 'varchar') + t1.add_column(c11) + t1.add_column(c12) + t2 = Table('authors') + c21 = Column('id', 'integer', pk=True) + c22 = Column('name', 'varchar') + t2.add_column(c21) + t2.add_column(c22) + ref0 = Reference('>', [c11], [c21]) + ref = Reference('<>', [c11, c12], [c21, c22]) + + self.assertIsNone(ref0.join_table) + self.assertEqual(ref.join_table.name, 'books_authors') + self.assertEqual(len(ref.join_table.columns), 4) + + def test_join_table_none(self) -> None: + t1 = Table('books') + c11 = Column('id', 'integer', pk=True) + c12 = Column('author', 'varchar') + t1.add_column(c11) + t1.add_column(c12) + t2 = Table('authors') + c21 = Column('id', 'integer', pk=True) + c22 = Column('name', 'varchar') + t2.add_column(c21) + t2.add_column(c22) + ref = Reference('<>', [c11], [c21]) + + _table1 = ref.table1 + ref.col1[0].table = None + with self.assertRaises(TableNotFoundError): + ref.join_table + + ref.col1[0].table = _table1 + ref.col2[0].table = None + with self.assertRaises(TableNotFoundError): + ref.join_table + + +class TestReferenceInline(TestCase): + def test_validate_different_tables(self): + t1 = Table('products') + c11 = Column('id', 'integer') + c12 = Column('name', 'varchar2') + t1.add_column(c11) + t1.add_column(c12) + t2 = Table('names') + c21 = Column('name_val', 'varchar2') + t2.add_column(c21) + ref = Reference( + '<', + [c12, c21], + [c21], + name='nameref', + comment='Reference comment\nmultiline', + on_update='CASCADE', + on_delete='SET NULL', + inline=True + ) + with self.assertRaises(DBMLError): + ref._validate() + + ref = Reference( + '<', + [c11, c12], + [c21, c12], + name='nameref', + comment='Reference comment\nmultiline', + on_update='CASCADE', + on_delete='SET NULL', + inline=True + ) + with self.assertRaises(DBMLError): + ref._validate() + + def test_validate_no_table(self): + c1 = Column('id', 'integer') + c2 = Column('name', 'varchar2') + c3 = Column('age', 'number') + c4 = Column('active', 'boolean') + ref1 = Reference( + '<', + c1, + c2 + ) + with self.assertRaises(TableNotFoundError): + validate_for_sql(ref1) + table = Table('name') + table.add_column(c1) + with self.assertRaises(TableNotFoundError): + validate_for_sql(ref1) + table.delete_column(c1) + + ref2 = Reference( + '<', + [c1, c2], + [c3, c4] + ) + with self.assertRaises(TableNotFoundError): + validate_for_sql(ref2) + table = Table('name') + table.add_column(c1) + table.add_column(c2) + with self.assertRaises(TableNotFoundError): + validate_for_sql(ref2) diff --git a/test/test_classes/test_sticky_note.py b/test/test_classes/test_sticky_note.py new file mode 100644 index 0000000..24acb44 --- /dev/null +++ b/test/test_classes/test_sticky_note.py @@ -0,0 +1,28 @@ +from pydbml._classes.sticky_note import StickyNote + + +def test_init_types(): + n1 = StickyNote('mynote', 'My note text') + n2 = StickyNote('mynote', 3) + n3 = StickyNote('mynote', [1, 2, 3]) + n4 = StickyNote('mynote', None) + + assert n1.text == 'My note text' + assert n2.text == '3' + assert n3.text == '[1, 2, 3]' + assert n4.text == '' + assert n1.name == n2.name == n3.name == n4.name == 'mynote' + + +def test_str(sticky_note1: StickyNote) -> None: + assert str(sticky_note1) == "StickyNote('mynote', 'Simple note')" + + +def test_repr(sticky_note1: StickyNote) -> None: + assert repr(sticky_note1) == "" + + +def test_bool(sticky_note1: StickyNote) -> None: + assert bool(sticky_note1) is True + sticky_note1.text = '' + assert bool(sticky_note1) is False diff --git a/test/test_classes/test_table.py b/test/test_classes/test_table.py new file mode 100644 index 0000000..345ee99 --- /dev/null +++ b/test/test_classes/test_table.py @@ -0,0 +1,201 @@ +from unittest import TestCase + +import pytest + +from pydbml.classes import Column +from pydbml.classes import Expression +from pydbml.classes import Index +from pydbml.classes import Note +from pydbml.classes import Reference +from pydbml.classes import Table +from pydbml.database import Database +from pydbml.exceptions import ColumnNotFoundError +from pydbml.exceptions import IndexNotFoundError +from pydbml.exceptions import UnknownDatabaseError + + +class TestTable(TestCase): + def test_schema(self) -> None: + t = Table('test') + self.assertEqual(t.schema, 'public') + t2 = Table('test', 'schema1') + self.assertEqual(t2.schema, 'schema1') + + def test_getitem(self) -> None: + t = Table('products') + c1 = Column('col1', 'integer') + c2 = Column('col2', 'integer') + c3 = Column('col3', 'integer') + t.add_column(c1) + t.add_column(c2) + t.add_column(c3) + self.assertIs(t['col1'], c1) + self.assertIs(t[1], c2) + with self.assertRaises(IndexError): + t[22] + with self.assertRaises(TypeError): + t[None] + with self.assertRaises(ColumnNotFoundError): + t['wrong'] + + def test_init_with_columns(self) -> None: + t = Table( + 'products', + columns=( + Column('col1', 'integer'), + Column('col2', 'integer'), + Column('col3', 'integer'), + ) + ) + self.assertIs(t['col1'].table, t) + self.assertIs(t['col2'].table, t) + self.assertIs(t['col3'].table, t) + + def test_init_with_indexes(self) -> None: + c1 = Column('col1', 'integer') + c2 = Column('col2', 'integer') + c3 = Column('col3', 'integer') + t = Table( + 'products', + columns=[c1, c2, c3], + indexes=[Index(subjects=[c1])] + ) + self.assertIs(t.indexes[0].table, t) + + def test_get(self) -> None: + t = Table('products') + c1 = Column('col1', 'integer') + c2 = Column('col2', 'integer') + c3 = Column('col3', 'integer') + t.add_column(c1) + t.add_column(c2) + self.assertIs(t.get(0), c1) + self.assertIs(t.get('col2'), c2) + self.assertIsNone(t.get('wrong')) + self.assertIsNone(t.get(22)) + self.assertIs(t.get('wrong', c2), c2) + self.assertIs(t.get(22, c2), c2) + self.assertIs(t.get('wrong', c3), c3) + + def test_iter(self) -> None: + t = Table('products') + c1 = Column('col1', 'integer') + c2 = Column('col2', 'integer') + c3 = Column('col3', 'integer') + t.add_column(c1) + t.add_column(c2) + t.add_column(c3) + for i1, i2 in zip(t, [c1, c2, c3]): + self.assertIs(i1, i2) + + def test_add_column(self) -> None: + t = Table('products') + c1 = Column('id', 'integer') + c2 = Column('name', 'varchar2') + t.add_column(c1) + t.add_column(c2) + self.assertEqual(c1.table, t) + self.assertEqual(c2.table, t) + self.assertEqual(t.columns, [c1, c2]) + with self.assertRaises(TypeError): + t.add_column('wrong type') + + def test_delete_column(self) -> None: + t = Table('products') + c1 = Column('id', 'integer') + c2 = Column('name', 'varchar2') + t.add_column(c1) + t.add_column(c2) + t.delete_column(c1) + self.assertIsNone(c1.table) + self.assertNotIn(c1, t.columns) + t.delete_column(0) + self.assertIsNone(c2.table) + self.assertNotIn(c2, t.columns) + with self.assertRaises(ColumnNotFoundError): + t.delete_column(c2) + + def test_add_index(self) -> None: + t = Table('products') + c1 = Column('id', 'integer') + c2 = Column('name', 'varchar2') + i1 = Index([c1]) + i2 = Index([c2]) + t.add_column(c1) + t.add_column(c2) + t.add_index(i1) + t.add_index(i2) + self.assertEqual(i1.table, t) + self.assertEqual(i2.table, t) + self.assertEqual(t.indexes, [i1, i2]) + with self.assertRaises(TypeError): + t.add_index('wrong_type') + + def test_delete_index(self) -> None: + t = Table('products') + c1 = Column('id', 'integer') + c2 = Column('name', 'varchar2') + i1 = Index([c1]) + i2 = Index([c2]) + t.add_column(c1) + t.add_column(c2) + t.add_index(i1) + t.add_index(i2) + t.delete_index(0) + self.assertIsNone(i1.table) + self.assertNotIn(i1, t.indexes) + t.delete_index(i2) + self.assertIsNone(i2.table) + self.assertNotIn(i2, t.indexes) + with self.assertRaises(IndexNotFoundError): + t.delete_index(i1) + + def test_get_refs(self): + t = Table('products') + with self.assertRaises(UnknownDatabaseError): + t.get_refs() + c11 = Column('id', 'integer') + c12 = Column('name', 'varchar2') + t.add_column(c11) + t.add_column(c12) + t2 = Table('names') + c21 = Column('id', 'integer') + c22 = Column('name_val', 'varchar2') + t2.add_column(c21) + t2.add_column(c22) + s = Database() + s.add(t) + s.add(t2) + r1 = Reference('>', c12, c22) + r2 = Reference('-', c11, c21) + r3 = Reference('<', c11, c22) + s.add(r1) + s.add(r2) + s.add(r3) + self.assertEqual(t.get_refs(), [r1, r2, r3]) + self.assertEqual(t2.get_refs(), []) + + def test_note_property(self): + note1 = Note('table note') + t = Table(name='test') + t.note = note1 + self.assertIs(t.note.parent, t) + + +class TestAddIndex: + @staticmethod + def test_wrong_type(table1: Table) -> None: + with pytest.raises(TypeError): + table1.add_index('wrong_type') + + + @staticmethod + def test_column_not_in_table(table1: Table, table2: Table) -> None: + with pytest.raises(ColumnNotFoundError): + table1.add_index(Index([table2.columns[0]])) + + @staticmethod + def test_ok(table1: Table) -> None: + i = Index([table1.columns[0]]) + table1.add_index(i) + assert i.table is table1 diff --git a/test/test_classes/test_table_group.py b/test/test_classes/test_table_group.py new file mode 100644 index 0000000..a718fac --- /dev/null +++ b/test/test_classes/test_table_group.py @@ -0,0 +1,31 @@ +from unittest import TestCase + +from pydbml.classes import Table +from pydbml.classes import TableGroup + + +class TestTableGroup(TestCase): + def test_getitem(self) -> None: + merchants = Table('merchants') + countries = Table('countries') + customers = Table('customers') + tg = TableGroup( + 'mytg', + [merchants, countries, customers], + comment='My table group\nmultiline comment' + ) + self.assertIs(tg[1], countries) + with self.assertRaises(IndexError): + tg[22] + + def test_iter(self) -> None: + merchants = Table('merchants') + countries = Table('countries') + customers = Table('customers') + tg = TableGroup( + 'mytg', + [merchants, countries, customers], + comment='My table group\nmultiline comment' + ) + for i1, i2 in zip(tg, [merchants, countries, customers]): + self.assertIs(i1, i2) diff --git a/test/test_column.py b/test/test_column.py deleted file mode 100644 index 1b795af..0000000 --- a/test/test_column.py +++ /dev/null @@ -1,225 +0,0 @@ -from unittest import TestCase - -from pyparsing import ParseException -from pyparsing import ParseSyntaxException -from pyparsing import ParserElement - -from pydbml.definitions.column import column_setting -from pydbml.definitions.column import column_settings -from pydbml.definitions.column import column_type -from pydbml.definitions.column import constraint -from pydbml.definitions.column import default -from pydbml.definitions.column import table_column - - -ParserElement.setDefaultWhitespaceChars(' \t\r') - - -class TestColumnType(TestCase): - def test_simple(self) -> None: - val = 'int' - res = column_type.parseString(val, parseAll=True) - self.assertEqual(res[0], val) - - def test_quoted(self) -> None: - val = '"mytype"' - res = column_type.parseString(val, parseAll=True) - self.assertEqual(res[0], val) - - def test_expression(self) -> None: - val = 'varchar(255)' - res = column_type.parseString(val, parseAll=True) - self.assertEqual(res[0], val) - - def test_symbols(self) -> None: - val = '(*#^)' - with self.assertRaises(ParseException): - column_type.parseString(val, parseAll=True) - - def test_string(self) -> None: - val = "'mytype'" - with self.assertRaises(ParseException): - column_type.parseString(val, parseAll=True) - - -class TestDefault(TestCase): - def test_string(self) -> None: - val = "default: 'string'" - val2 = "default: \n\n'string'" - expected = 'string' - res = default.parseString(val, parseAll=True) - self.assertEqual(res[0], expected) - res = default.parseString(val2, parseAll=True) - self.assertEqual(res[0], expected) - - def test_expression(self) -> None: - expr1 = 'datetime.now()' - expr2 = 'datetime\nnow()' - val = f"default: `{expr1}`" - val2 = f"default: `{expr2}`" - val3 = f"default: ``" - res = default.parseString(val, parseAll=True) - self.assertEqual(res[0], f'({expr1})') - res = default.parseString(val2, parseAll=True) - self.assertEqual(res[0], f'({expr2})') - res = default.parseString(val3, parseAll=True) - self.assertEqual(res[0], '()') - - def test_bool(self) -> None: - vals = ['true', 'false', 'null'] - exps = [True, False, 'NULL'] - while len(vals) > 0: - res = default.parseString(f'default: {vals.pop()}', parseAll=True) - self.assertEqual(exps.pop(), res[0]) - - def test_numbers(self) -> None: - vals = [0, 17, 13.3, 2.0] - while len(vals) > 0: - cur = vals.pop() - res = default.parseString(f'default: {cur}', parseAll=True) - self.assertEqual((cur), res[0]) - - def test_wrong(self) -> None: - val = "default: now" - with self.assertRaises(ParseSyntaxException): - default.parseString(val, parseAll=True) - - -class TestColumnSetting(TestCase): - def test_pass(self) -> None: - vals = ['not null', - 'null', - 'primary key', - 'pk', - 'unique', - 'default: 123', - 'ref: > table.column'] - for val in vals: - column_setting.parseString(val, parseAll=True) - - def test_fail(self) -> None: - vals = ['wrong', - '`null`', - '"pk"'] - for val in vals: - with self.assertRaises(ParseException): - column_setting.parseString(val, parseAll=True) - - -class TestColumnSettings(TestCase): - def test_nulls(self) -> None: - res = column_settings.parseString('[NULL]', parseAll=True) - self.assertNotIn('not_null', res[0]) - res = column_settings.parseString('[NOT NULL]', parseAll=True) - self.assertTrue(res[0]['not_null']) - res = column_settings.parseString('[NULL, NOT NULL]', parseAll=True) - self.assertTrue(res[0]['not_null']) - res = column_settings.parseString('[NOT NULL, NULL]', parseAll=True) - self.assertNotIn('not_null', res[0]) - - def test_pk(self) -> None: - res = column_settings.parseString('[pk]', parseAll=True) - self.assertTrue(res[0]['pk']) - res = column_settings.parseString('[primary key]', parseAll=True) - self.assertTrue(res[0]['pk']) - res = column_settings.parseString('[primary key, pk]', parseAll=True) - self.assertTrue(res[0]['pk']) - - def test_unique_increment(self) -> None: - res = column_settings.parseString('[unique, increment]', parseAll=True) - self.assertTrue(res[0]['unique']) - self.assertTrue(res[0]['autoinc']) - - def test_refs(self) -> None: - res = column_settings.parseString('[ref: > table.column]', parseAll=True) - self.assertEqual(len(res[0]['ref_blueprints']), 1) - res = column_settings.parseString('[ref: - table.column, ref: < table2.column2]', parseAll=True) - self.assertEqual(len(res[0]['ref_blueprints']), 2) - - def test_note_default(self) -> None: - res = column_settings.parseString('[default: 123, note: "mynote"]', parseAll=True) - self.assertIn('note', res[0]) - self.assertEqual(res[0]['default'], 123) - - def test_wrong(self) -> None: - val = "[wrong]" - with self.assertRaises(ParseSyntaxException): - column_settings.parseString(val, parseAll=True) - - -class TestConstraint(TestCase): - def test_should_parse(self) -> None: - constraint.parseString('unique', parseAll=True) - constraint.parseString('pk', parseAll=True) - - def test_should_fail(self) -> None: - with self.assertRaises(ParseException): - constraint.parseString('wrong', parseAll=True) - - -class TestColumn(TestCase): - def test_no_settings(self) -> None: - val = 'address varchar(255)\n' - res = table_column.parseString(val, parseAll=True) - self.assertEqual(res[0].name, 'address') - self.assertEqual(res[0].type, 'varchar(255)') - - def test_with_constraint(self) -> None: - val = 'user_id integer unique\n' - res = table_column.parseString(val, parseAll=True) - self.assertEqual(res[0].name, 'user_id') - self.assertEqual(res[0].type, 'integer') - self.assertTrue(res[0].unique) - val2 = 'user_id integer pk unique\n' - res2 = table_column.parseString(val2, parseAll=True) - self.assertEqual(res2[0].name, 'user_id') - self.assertEqual(res2[0].type, 'integer') - self.assertTrue(res2[0].unique) - self.assertTrue(res2[0].pk) - - def test_with_settings(self) -> None: - val = "_test_ \"mytype\" [unique, not null, note: 'to include unit number']\n" - res = table_column.parseString(val, parseAll=True) - self.assertEqual(res[0].name, '_test_') - self.assertEqual(res[0].type, '\"mytype\"') - self.assertTrue(res[0].unique) - self.assertTrue(res[0].not_null) - self.assertTrue(res[0].note is not None) - - def test_settings_and_constraints(self) -> None: - val = "_test_ \"mytype\" unique pk [not null]\n" - res = table_column.parseString(val, parseAll=True) - self.assertEqual(res[0].name, '_test_') - self.assertEqual(res[0].type, '\"mytype\"') - self.assertTrue(res[0].unique) - self.assertTrue(res[0].not_null) - self.assertTrue(res[0].pk) - - def test_comment_above(self) -> None: - val = '//comment above\naddress varchar\n' - res = table_column.parseString(val, parseAll=True) - self.assertEqual(res[0].name, 'address') - self.assertEqual(res[0].type, 'varchar') - self.assertEqual(res[0].comment, 'comment above') - - def test_comment_after(self) -> None: - val = 'address varchar //comment after\n' - res = table_column.parseString(val, parseAll=True) - self.assertEqual(res[0].name, 'address') - self.assertEqual(res[0].type, 'varchar') - self.assertEqual(res[0].comment, 'comment after') - val2 = 'user_id integer pk unique //comment after\n' - res2 = table_column.parseString(val2, parseAll=True) - self.assertEqual(res2[0].name, 'user_id') - self.assertEqual(res2[0].type, 'integer') - self.assertTrue(res2[0].unique) - self.assertTrue(res2[0].pk) - self.assertEqual(res2[0].comment, 'comment after') - val3 = "_test_ \"mytype\" unique pk [not null] //comment after\n" - res3 = table_column.parseString(val3, parseAll=True) - self.assertEqual(res3[0].name, '_test_') - self.assertEqual(res3[0].type, '\"mytype\"') - self.assertTrue(res3[0].unique) - self.assertTrue(res3[0].not_null) - self.assertTrue(res3[0].pk) - self.assertEqual(res3[0].comment, 'comment after') diff --git a/test/test_data/dbml_schema_def.dbml b/test/test_data/dbml_schema_def.dbml new file mode 100644 index 0000000..a5d509a --- /dev/null +++ b/test/test_data/dbml_schema_def.dbml @@ -0,0 +1,64 @@ +Table "ecommerce"."users" as EU { + id int [pk] + name varchar + ejs job_status + ejs2 public.job_status + eg schemaB.gender + eg2 gender +} + +Table public.users { + id int [pk] + name varchar + pjs job_status + pjs2 public.job_status + pg schemaB.gender + pg2 gender +} + +Table products { + id int [pk] + name varchar +} + +Table schemaA.products as A { + id int [pk] + name varchar [ref: > EU.id] +} + +Table schemaA.locations { + id int [pk] + name varchar [ref: > users.id ] +} + +Ref: "public".users.id < EU.id + +Ref name_optional { + users.name < ecommerce.users.id +} + +TableGroup tablegroup_name { // tablegroup is case-insensitive. + public.products + users + ecommerce.users + A +} + +enum job_status { + created2 [note: 'abcdef'] + running2 + done2 + failure2 +} + +enum schemaB.gender { + man + woman + nonbinary +} + +enum gender { + man2 + woman2 + nonbinary2 +} diff --git a/test/test_data/docs/enum_definition.dbml b/test/test_data/docs/enum_definition.dbml index aeb07f7..b82579d 100644 --- a/test/test_data/docs/enum_definition.dbml +++ b/test/test_data/docs/enum_definition.dbml @@ -5,8 +5,18 @@ enum job_status { failure } +enum grade { + "A+" + "A" + "A-" + "Not Yet Set" +} + + + Table jobs { id integer status job_status + grade grade } diff --git a/test/test_data/docs/sticky_notes.dbml b/test/test_data/docs/sticky_notes.dbml new file mode 100644 index 0000000..7f03cf7 --- /dev/null +++ b/test/test_data/docs/sticky_notes.dbml @@ -0,0 +1,10 @@ +Note single_line_note { + 'This is a single line note' +} + +Note multiple_lines_note { +''' + This is a multiple lines note + This string can spans over multiple lines. +''' +} diff --git a/test/test_data/editing.dbml b/test/test_data/editing.dbml new file mode 100644 index 0000000..2aa1ed9 --- /dev/null +++ b/test/test_data/editing.dbml @@ -0,0 +1,35 @@ +Table "products" { + "id" int [pk] + "name" varchar + "merchant_id" int [not null] + "price" int + "status" "product status" + "created_at" datetime [default: `now()`] + + + Indexes { + (merchant_id, status) [name: "product_status"] + id [type: hash, unique] + } +} + +Enum "product status" { + "Out of Stock" + "In Stock" +} + +Ref:"merchants"."id" < "products"."merchant_id" + + +Table "merchants" { + "id" int [pk] + "merchant_name" varchar + "country_code" int + "created_at" varchar + "admin_id" int +} + +TableGroup g1 { + products + merchants +} diff --git a/test/test_data/integration1.dbml b/test/test_data/integration1.dbml new file mode 100644 index 0000000..b24c9d7 --- /dev/null +++ b/test/test_data/integration1.dbml @@ -0,0 +1,44 @@ +Project "my project" { + author: 'me' + reason: 'testing' +} + +Enum "level" { + "junior" + "middle" + "senior" +} + +Table "Employees" as "emp" { + "id" integer [pk, increment] + "name" varchar [note: 'Full employee name'] + "age" number + "level" level + "favorite_book_id" integer +} + +Table "books" { + "id" integer [pk, increment] + "title" varchar + "author" varchar + "country_id" integer +} + +Table "countries" { + "id" integer [ref: < "books"."country_id", pk, increment] + "name" varchar2 [unique] + + indexes { + name [unique] + `UPPER(name)` + } +} + +Ref { + "Employees"."favorite_book_id" > "books"."id" +} + +TableGroup "Unanimate" { + "books" + "countries" +} \ No newline at end of file diff --git a/test/test_data/integration1.sql b/test/test_data/integration1.sql new file mode 100644 index 0000000..f648fb1 --- /dev/null +++ b/test/test_data/integration1.sql @@ -0,0 +1,34 @@ +CREATE TYPE "level" AS ENUM ( + 'junior', + 'middle', + 'senior' +); + +CREATE TABLE "books" ( + "id" integer PRIMARY KEY AUTOINCREMENT, + "title" varchar, + "author" varchar, + "country_id" integer, + CONSTRAINT "Country Reference" FOREIGN KEY ("country_id") REFERENCES "countries" ("id") +); + +CREATE TABLE "Employees" ( + "id" integer PRIMARY KEY AUTOINCREMENT, + "name" varchar, + "age" number DEFAULT 0, + "level" level, + "favorite_book_id" integer +); + +COMMENT ON COLUMN "Employees"."name" IS 'Full employee name'; + +CREATE TABLE "countries" ( + "id" integer PRIMARY KEY AUTOINCREMENT, + "name" varchar2 UNIQUE +); + +CREATE UNIQUE INDEX ON "countries" ("name"); + +CREATE INDEX ON "countries" ((UPPER(name))); + +ALTER TABLE "Employees" ADD FOREIGN KEY ("favorite_book_id") REFERENCES "books" ("id"); \ No newline at end of file diff --git a/test/test_data/notes.dbml b/test/test_data/notes.dbml new file mode 100644 index 0000000..d944c26 --- /dev/null +++ b/test/test_data/notes.dbml @@ -0,0 +1,60 @@ +Project "my project" { + author: 'me' + reason: 'testing' + Note: ''' + # DBML - Database Markup Language + DBML (database markup language) is a simple, readable DSL language designed to define database structures. + + ## Benefits + + * It is simple, flexible and highly human-readable + * It is database agnostic, focusing on the essential database structure definition without worrying about the detailed syntaxes of each database + * Comes with a free, simple database visualiser at [dbdiagram.io](http://dbdiagram.io) + ''' +} + +Enum "level" { + "junior" [note: 'enum item note'] + "middle" + "senior" +} + +Table "orders" [headercolor: #fff] { + "id" int [pk, increment] + "user_id" int [unique, not null] + "status" orders_status [note: "test note"] + "created_at" varchar + Note: 'Simple one line note' +} + +Table "order_items" { + "order_id" int + "product_id" int + "quantity" int [default: 1] + Note: 'Lorem ipsum, dolor sit amet consectetur adipisicing elit. Doloremque exercitationem facere eos, quod error consectetur.' + indexes { + order_id [unique, Note: 'Index note'] + `ROUND(quantity)` + } +} + +Table "products" { + "id" int [pk] + "name" varchar + "merchant_id" int [not null] + "price" int + "status" "product status" + "created_at" datetime [default: `now()`] + Note { + '''Indented note which is actually a Markdown formated string: + + - List item 1 + - Another list item + + ```[python + def test(): + print('Hello world!') + return 1 + ```''' + } +} diff --git a/test/test_data/relationships_aliases.dbml b/test/test_data/relationships_aliases.dbml index 4e7265b..7e8589d 100644 --- a/test/test_data/relationships_aliases.dbml +++ b/test/test_data/relationships_aliases.dbml @@ -29,3 +29,11 @@ Table reviews2 as re2 { Table users2 as us2 { id integer } + +Table "alembic_version" { + "version_num" "character varying(32)" [not null] + +Indexes { + version_num [pk, name: "alembic_version_pk"] +} +} \ No newline at end of file diff --git a/test/test_database.py b/test/test_database.py new file mode 100644 index 0000000..095e75c --- /dev/null +++ b/test/test_database.py @@ -0,0 +1,374 @@ +import os + +from pathlib import Path +from unittest import TestCase +from unittest.mock import Mock + +from pydbml.classes import Column +from pydbml.classes import Enum +from pydbml.classes import EnumItem +from pydbml.classes import Project +from pydbml.classes import Reference +from pydbml.classes import Table +from pydbml.classes import TableGroup +from pydbml.database import Database +from pydbml.exceptions import DatabaseValidationError +from pydbml.constants import ONE_TO_MANY, MANY_TO_ONE, MANY_TO_MANY +from pydbml.renderer.sql.default.utils import reorder_tables_for_sql + +TEST_DATA_PATH = Path(os.path.abspath(__file__)).parent / 'test_data' + + +class TestDatabase(TestCase): + def test_add_table(self) -> None: + c = Column('test', 'varchar', True) + t = Table('test_table') + t.add_column(c) + database = Database() + res = database.add_table(t) + self.assertEqual(t.database, database) + self.assertIs(res, t) + self.assertIn(t, database.tables) + + def test_add_table_alias(self) -> None: + c = Column('test', 'varchar', True) + t = Table('test_table', alias='myalias') + t.add_column(c) + database = Database() + database.add_table(t) + self.assertIsInstance(t.alias, str) + self.assertIs(database[t.alias], t) + + def test_add_table_alias_bad(self) -> None: + c = Column('test', 'varchar', True) + t = Table('test', alias='myalias') + t.add_column(c) + database = Database() + database.add_table(t) + t2 = Table('test_table', alias='myalias') + with self.assertRaises(DatabaseValidationError): + database.add_table(t2) + self.assertIsNone(t2.database) + + def test_add_table_bad(self) -> None: + c = Column('test', 'varchar', True) + t = Table('test_table') + t.add_column(c) + database = Database() + database.add_table(t) + with self.assertRaises(DatabaseValidationError): + database.add_table(t) + t2 = Table('test_table') + with self.assertRaises(DatabaseValidationError): + database.add_table(t2) + + def test_delete_table(self) -> None: + c = Column('test', 'varchar', True) + t = Table('test_table', alias='myalias') + t.add_column(c) + database = Database() + database.add_table(t) + res = database.delete_table(t) + self.assertIsNone(t.database, database) + self.assertIs(res, t) + self.assertNotIn(t, database.tables) + self.assertNotIn('test_table', database.table_dict) + self.assertNotIn('myalias', database.table_dict) + + def test_delete_missing_table(self) -> None: + t = Table('test_table') + database = Database() + with self.assertRaises(DatabaseValidationError): + database.delete_table(t) + self.assertIsNone(t.database, database) + + def test_add_reference(self) -> None: + c = Column('test', 'varchar', True) + t = Table('test_table') + t.add_column(c) + database = Database() + database.add_table(t) + c2 = Column('test2', 'integer') + t2 = Table('test_table2') + t2.add_column(c2) + database.add_table(t2) + ref = Reference('>', c, c2) + res = database.add_reference(ref) + self.assertEqual(ref.database, database) + self.assertIs(res, ref) + self.assertIn(ref, database.refs) + + def test_add_reference_bad(self) -> None: + c = Column('test', 'varchar', True) + t = Table('test_table') + t.add_column(c) + database = Database() + database.add_table(t) + c2 = Column('test2', 'integer') + t2 = Table('test_table2') + t2.add_column(c2) + database.add_table(t2) + ref = Reference('>', c, c2) + database.add_reference(ref) + with self.assertRaises(DatabaseValidationError): + database.add_reference(ref) + + c3 = Column('test', 'varchar', True) + t3 = Table('test_table') + t3.add_column(c3) + database3 = Database() + database3.add_table(t3) + c32 = Column('test2', 'integer') + t32 = Table('test_table2') + t32.add_column(c32) + database3.add_table(t32) + ref3 = Reference('>', c3, c32) + with self.assertRaises(DatabaseValidationError): + database.add_reference(ref3) + + def test_delete_reference(self) -> None: + c = Column('test', 'varchar', True) + t = Table('test_table') + t.add_column(c) + database = Database() + database.add_table(t) + c2 = Column('test2', 'integer') + t2 = Table('test_table2') + t2.add_column(c2) + database.add_table(t2) + ref = Reference('>', c, c2) + res = database.add_reference(ref) + res = database.delete_reference(ref) + self.assertIsNone(ref.database, database) + self.assertIs(res, ref) + self.assertNotIn(ref, database.refs) + + def test_delete_missing_reference(self) -> None: + c = Column('test', 'varchar', True) + t = Table('test_table') + t.add_column(c) + database = Database() + database.add_table(t) + c2 = Column('test2', 'integer') + t2 = Table('test_table2') + t2.add_column(c2) + database.add_table(t2) + ref = Reference('>', c, c2) + with self.assertRaises(DatabaseValidationError): + database.delete_reference(ref) + self.assertIsNone(ref.database) + + def test_add_enum(self) -> None: + e = Enum('myenum', [EnumItem('a'), EnumItem('b')]) + database = Database() + res = database.add_enum(e) + self.assertEqual(e.database, database) + self.assertIs(res, e) + self.assertIn(e, database.enums) + + def test_add_enum_bad(self) -> None: + e = Enum('myenum', [EnumItem('a'), EnumItem('b')]) + database = Database() + database.add_enum(e) + with self.assertRaises(DatabaseValidationError): + database.add_enum(e) + e2 = Enum('myenum', [EnumItem('a2'), EnumItem('b2')]) + with self.assertRaises(DatabaseValidationError): + database.add_enum(e2) + + def test_delete_enum(self) -> None: + e = Enum('myenum', [EnumItem('a'), EnumItem('b')]) + database = Database() + database.add_enum(e) + res = database.delete_enum(e) + self.assertIsNone(e.database) + self.assertIs(res, e) + self.assertNotIn(e, database.enums) + + def test_delete_missing_enum(self) -> None: + e = Enum('myenum', [EnumItem('a'), EnumItem('b')]) + database = Database() + with self.assertRaises(DatabaseValidationError): + database.delete_enum(e) + self.assertIsNone(e.database) + + def test_add_table_group(self) -> None: + t1 = Table('table1') + t2 = Table('table2') + tg = TableGroup('mytablegroup', [t1, t2]) + database = Database() + res = database.add_table_group(tg) + self.assertEqual(tg.database, database) + self.assertIs(res, tg) + self.assertIn(tg, database.table_groups) + + def test_add_table_group_bad(self) -> None: + t1 = Table('table1') + t2 = Table('table2') + tg = TableGroup('mytablegroup', [t1, t2]) + database = Database() + database.add_table_group(tg) + with self.assertRaises(DatabaseValidationError): + database.add_table_group(tg) + tg2 = TableGroup('mytablegroup', [t2]) + with self.assertRaises(DatabaseValidationError): + database.add_table_group(tg2) + + def test_delete_table_group(self) -> None: + t1 = Table('table1') + t2 = Table('table2') + tg = TableGroup('mytablegroup', [t1, t2]) + database = Database() + database.add_table_group(tg) + res = database.delete_table_group(tg) + self.assertIsNone(tg.database) + self.assertIs(res, tg) + self.assertNotIn(tg, database.table_groups) + + def test_delete_missing_table_group(self) -> None: + t1 = Table('table1') + t2 = Table('table2') + tg = TableGroup('mytablegroup', [t1, t2]) + database = Database() + with self.assertRaises(DatabaseValidationError): + database.delete_table_group(tg) + self.assertIsNone(tg.database) + + def test_add_project(self) -> None: + p = Project('myproject') + database = Database() + res = database.add_project(p) + self.assertEqual(p.database, database) + self.assertIs(res, p) + self.assertIs(database.project, p) + + def test_add_another_project(self) -> None: + p = Project('myproject') + database = Database() + database.add_project(p) + p2 = Project('anotherproject') + res = database.add_project(p2) + self.assertEqual(p2.database, database) + self.assertIs(res, p2) + self.assertIs(database.project, p2) + self.assertIsNone(p.database) + + def test_delete_project(self) -> None: + p = Project('myproject') + database = Database() + database.add_project(p) + res = database.delete_project() + self.assertIsNone(p.database, database) + self.assertIs(res, p) + self.assertIsNone(database.project) + + def test_delete_missing_project(self) -> None: + database = Database() + with self.assertRaises(DatabaseValidationError): + database.delete_project() + + def test_geititem(self) -> None: + t1 = Table('table1') + t2 = Table('table2', schema='myschema') + database = Database() + database.add_table(t1) + database.add_table(t2) + self.assertIs(database['public.table1'], t1) + self.assertIs(database['myschema.table2'], t2) + self.assertIs(database[0], t1) + self.assertIs(database[1], t2) + with self.assertRaises(TypeError): + database[None] + with self.assertRaises(IndexError): + database[2] + with self.assertRaises(KeyError): + database['wrong'] + + def test_iter(self) -> None: + t1 = Table('table1') + t2 = Table('table2') + database = Database() + database.add_table(t1) + database.add_table(t2) + self.assertEqual(list(iter(database)), [t1, t2]) + + def test_add(self) -> None: + t1 = Table('table1') + t2 = Table('table2') + tg = TableGroup('mytablegroup', [t1, t2]) + e = Enum('myenum', [EnumItem('a'), EnumItem('b')]) + database = Database() + database.add(t1) + database.add(t2) + database.add(e) + database.add(tg) + self.assertIs(t1.database, database) + self.assertIs(t2.database, database) + self.assertIs(e.database, database) + self.assertIs(tg.database, database) + self.assertIn(t1, database.tables) + self.assertIn(t2, database.tables) + self.assertIn(tg, database.table_groups) + self.assertIn(e, database.enums) + + def test_add_bad(self) -> None: + class Test: + pass + t = Test() + database = Database() + with self.assertRaises(DatabaseValidationError): + database.add(t) + with self.assertRaises(AttributeError): + t.database + + def test_delete(self) -> None: + t1 = Table('table1') + c1 = Column('col1', 'int') + t1.add_column(c1) + t2 = Table('table2') + c2 = Column('col2', 'int') + t2.add_column(c2) + ref = Reference('>', [c1], [c2]) + tg = TableGroup('mytablegroup', [t1, t2]) + e = Enum('myenum', [EnumItem('a'), EnumItem('b')]) + p = Project('myproject') + database = Database() + database.add(t1) + database.add(t2) + database.add(e) + database.add(tg) + database.add(ref) + database.add(p) + + database.delete(t1) + database.delete(t2) + database.delete(e) + database.delete(tg) + database.delete(ref) + database.delete(p) + self.assertIsNone(t1.database) + self.assertIsNone(t2.database) + self.assertIsNone(e.database) + self.assertIsNone(tg.database) + self.assertIsNone(ref.database) + self.assertIsNone(p.database) + self.assertIsNone(database.project) + self.assertNotIn(t1, database.tables) + self.assertNotIn(t2, database.tables) + self.assertNotIn(tg, database.table_groups) + self.assertNotIn(e, database.enums) + self.assertNotIn(ref, database.refs) + + def test_delete_bad(self) -> None: + class Test: + pass + t = Test() + database = Database() + with self.assertRaises(DatabaseValidationError): + database.delete(t) + with self.assertRaises(AttributeError): + t.database + + +def test_repr() -> None: + assert repr(Database()) == "" diff --git a/test/test_definitions/__init__.py b/test/test_definitions/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/test_definitions/test_column.py b/test/test_definitions/test_column.py new file mode 100644 index 0000000..335949a --- /dev/null +++ b/test/test_definitions/test_column.py @@ -0,0 +1,270 @@ +from textwrap import dedent +from unittest import TestCase + +from pyparsing import ParseException +from pyparsing import ParseSyntaxException +from pyparsing import ParserElement + +from pydbml.definitions.column import column_setting, table_column_with_properties +from pydbml.definitions.column import column_settings +from pydbml.definitions.column import column_type +from pydbml.definitions.column import constraint +from pydbml.definitions.column import default +from pydbml.definitions.column import table_column +from pydbml.parser.blueprints import ExpressionBlueprint + + +ParserElement.set_default_whitespace_chars(" \t\r") + + +class TestColumnType(TestCase): + def test_simple(self) -> None: + val = "int" + res = column_type.parse_string(val, parseAll=True) + self.assertEqual(res[0], val) + + def test_quoted(self) -> None: + val = '"mytype"' + res = column_type.parse_string(val, parseAll=True) + self.assertEqual(res[0], "mytype") + + def test_with_schema(self) -> None: + val = "myschema.mytype" + res = column_type.parse_string(val, parseAll=True) + self.assertEqual(res[0], val) + + def test_expression(self) -> None: + val = "varchar(255)" + res = column_type.parse_string(val, parseAll=True) + self.assertEqual(res[0], val) + + def test_array(self) -> None: + val = "int[]" + res = column_type.parse_string(val, parseAll=True) + self.assertEqual(res[0], val) + + def test_symbols(self) -> None: + val = "(*#^)" + with self.assertRaises(ParseException): + column_type.parse_string(val, parseAll=True) + + def test_string(self) -> None: + val = "'mytype'" + with self.assertRaises(ParseException): + column_type.parse_string(val, parseAll=True) + + +class TestDefault(TestCase): + def test_string(self) -> None: + val = "default: 'string'" + val2 = "default: \n\n'string'" + expected = "string" + res = default.parse_string(val, parseAll=True) + self.assertEqual(res[0], expected) + res = default.parse_string(val2, parseAll=True) + self.assertEqual(res[0], expected) + + def test_expression(self) -> None: + expr1 = "datetime.now()" + expr2 = "datetime\nnow()" + val = f"default: `{expr1}`" + val2 = f"default: `{expr2}`" + val3 = f"default: ``" + res = default.parse_string(val, parseAll=True) + self.assertIsInstance(res[0], ExpressionBlueprint) + self.assertEqual(res[0].text, expr1) + res = default.parse_string(val2, parseAll=True) + self.assertIsInstance(res[0], ExpressionBlueprint) + self.assertEqual(res[0].text, expr2) + res = default.parse_string(val3, parseAll=True) + self.assertIsInstance(res[0], ExpressionBlueprint) + self.assertEqual(res[0].text, "") + + def test_bool(self) -> None: + vals = ["true", "false", "null"] + exps = [True, False, "NULL"] + while len(vals) > 0: + res = default.parse_string(f"default: {vals.pop()}", parseAll=True) + self.assertEqual(exps.pop(), res[0]) + + def test_numbers(self) -> None: + vals = [0, 17, 13.3, 2.0] + while len(vals) > 0: + cur = vals.pop() + res = default.parse_string(f"default: {cur}", parseAll=True) + self.assertEqual((cur), res[0]) + + def test_wrong(self) -> None: + val = "default: now" + with self.assertRaises(ParseSyntaxException): + default.parse_string(val, parseAll=True) + + +class TestColumnSetting(TestCase): + def test_pass(self) -> None: + vals = [ + "not null", + "null", + "primary key", + "pk", + "unique", + "default: 123", + "ref: > table.column", + ] + for val in vals: + column_setting.parse_string(val, parseAll=True) + + def test_fail(self) -> None: + vals = ["wrong", "`null`", '"pk"'] + for val in vals: + with self.assertRaises(ParseException): + column_setting.parse_string(val, parseAll=True) + + +class TestColumnSettings(TestCase): + def test_nulls(self) -> None: + res = column_settings.parse_string("[NULL]", parseAll=True) + self.assertNotIn("not_null", res[0]) + res = column_settings.parse_string("[NOT NULL]", parseAll=True) + self.assertTrue(res[0]["not_null"]) + res = column_settings.parse_string("[NULL, NOT NULL]", parseAll=True) + self.assertTrue(res[0]["not_null"]) + res = column_settings.parse_string("[NOT NULL, NULL]", parseAll=True) + self.assertNotIn("not_null", res[0]) + + def test_pk(self) -> None: + res = column_settings.parse_string("[pk]", parseAll=True) + self.assertTrue(res[0]["pk"]) + res = column_settings.parse_string("[primary key]", parseAll=True) + self.assertTrue(res[0]["pk"]) + res = column_settings.parse_string("[primary key, pk]", parseAll=True) + self.assertTrue(res[0]["pk"]) + + def test_unique_increment(self) -> None: + res = column_settings.parse_string("[unique, increment]", parseAll=True) + self.assertTrue(res[0]["unique"]) + self.assertTrue(res[0]["autoinc"]) + + def test_refs(self) -> None: + res = column_settings.parse_string("[ref: > table.column]", parseAll=True) + self.assertEqual(len(res[0]["ref_blueprints"]), 1) + res = column_settings.parse_string( + "[ref: - table.column, ref: < table2.column2]", parseAll=True + ) + self.assertEqual(len(res[0]["ref_blueprints"]), 2) + + def test_note_default(self) -> None: + res = column_settings.parse_string( + '[default: 123, note: "mynote"]', parseAll=True + ) + self.assertIn("note", res[0]) + self.assertEqual(res[0]["default"], 123) + + def test_wrong(self) -> None: + val = "[wrong]" + with self.assertRaises(ParseSyntaxException): + column_settings.parse_string(val, parseAll=True) + + +class TestConstraint(TestCase): + def test_should_parse(self) -> None: + constraint.parse_string("unique", parseAll=True) + constraint.parse_string("pk", parseAll=True) + + def test_should_fail(self) -> None: + with self.assertRaises(ParseException): + constraint.parse_string("wrong", parseAll=True) + + +class TestColumn(TestCase): + def test_no_settings(self) -> None: + val = "address varchar(255)\n" + res = table_column.parse_string(val, parseAll=True) + self.assertEqual(res[0].name, "address") + self.assertEqual(res[0].type, "varchar(255)") + + def test_with_constraint(self) -> None: + val = "user_id integer unique\n" + res = table_column.parse_string(val, parseAll=True) + self.assertEqual(res[0].name, "user_id") + self.assertEqual(res[0].type, "integer") + self.assertTrue(res[0].unique) + val2 = "user_id integer pk unique\n" + res2 = table_column.parse_string(val2, parseAll=True) + self.assertEqual(res2[0].name, "user_id") + self.assertEqual(res2[0].type, "integer") + self.assertTrue(res2[0].unique) + self.assertTrue(res2[0].pk) + + def test_with_settings(self) -> None: + val = "_test_ \"mytype\" [unique, not null, note: 'to include unit number']\n" + res = table_column.parse_string(val, parseAll=True) + self.assertEqual(res[0].name, "_test_") + self.assertEqual(res[0].type, "mytype") + self.assertTrue(res[0].unique) + self.assertTrue(res[0].not_null) + self.assertTrue(res[0].note is not None) + + def test_multiline_settings(self) -> None: + val = dedent("""_test_ \"mytype\" [ + unique, + not null, + note: 'to include unit number' + ] + """) + res = table_column.parse_string(val, parseAll=True) + self.assertEqual(res[0].name, "_test_") + self.assertEqual(res[0].type, "mytype") + self.assertTrue(res[0].unique) + self.assertTrue(res[0].not_null) + self.assertTrue(res[0].note is not None) + + def test_enum_type_bad(self) -> None: + val = "_test_ myschema.mytype(12) [unique]\n" + with self.assertRaises(ParseException): + table_column.parse_string(val, parseAll=True) + + def test_settings_and_constraints(self) -> None: + val = '_test_ "mytype" unique pk [not null]\n' + res = table_column.parse_string(val, parseAll=True) + self.assertEqual(res[0].name, "_test_") + self.assertEqual(res[0].type, "mytype") + self.assertTrue(res[0].unique) + self.assertTrue(res[0].not_null) + self.assertTrue(res[0].pk) + + def test_comment_above(self) -> None: + val = "//comment above\naddress varchar\n" + res = table_column.parse_string(val, parseAll=True) + self.assertEqual(res[0].name, "address") + self.assertEqual(res[0].type, "varchar") + self.assertEqual(res[0].comment, "comment above") + + def test_comment_after(self) -> None: + val = "address varchar //comment after\n" + res = table_column.parse_string(val, parseAll=True) + self.assertEqual(res[0].name, "address") + self.assertEqual(res[0].type, "varchar") + self.assertEqual(res[0].comment, "comment after") + val2 = "user_id integer pk unique //comment after\n" + res2 = table_column.parse_string(val2, parseAll=True) + self.assertEqual(res2[0].name, "user_id") + self.assertEqual(res2[0].type, "integer") + self.assertTrue(res2[0].unique) + self.assertTrue(res2[0].pk) + self.assertEqual(res2[0].comment, "comment after") + val3 = '_test_ "mytype" unique pk [not null] //comment after\n' + res3 = table_column.parse_string(val3, parseAll=True) + self.assertEqual(res3[0].name, "_test_") + self.assertEqual(res3[0].type, "mytype") + self.assertTrue(res3[0].unique) + self.assertTrue(res3[0].not_null) + self.assertTrue(res3[0].pk) + self.assertEqual(res3[0].comment, "comment after") + + +def test_properties() -> None: + val = "address varchar(255) [unique, foo: 'bar', baz: '''123''']" + res = table_column_with_properties.parse_string(val, parseAll=True) + assert res[0].properties == {"foo": "bar", "baz": '123'} + diff --git a/test/test_common.py b/test/test_definitions/test_common.py similarity index 67% rename from test/test_common.py rename to test/test_definitions/test_common.py index 3376d8f..1107fb0 100644 --- a/test/test_common.py +++ b/test/test_definitions/test_common.py @@ -9,77 +9,86 @@ from pydbml.definitions.common import note_object -ParserElement.setDefaultWhitespaceChars(' \t\r') +ParserElement.set_default_whitespace_chars(' \t\r') class TestComment(TestCase): def test_comment_endstring(self) -> None: val = '//test comment' - res = comment.parseString(val, parseAll=True) + res = comment.parse_string(val, parseAll=True) self.assertEqual(res[0], 'test comment') def test_comment_endline(self) -> None: val = '//test comment\n\n\n\n\n' - res = comment.parseString(val) + res = comment.parse_string(val) self.assertEqual(res[0], 'test comment') + def test_multiline_comment(self) -> None: + val = '/*test comment*/' + res = comment.parse_string(val) + self.assertEqual(res[0], 'test comment') + + val2 = '/*\nline1\nline2\nline3\n*/' + res2 = comment.parse_string(val2) + self.assertEqual(res2[0], '\nline1\nline2\nline3\n') + class Test_c(TestCase): def test_comment(self) -> None: val = '\n\n\n\n//comment line 1\n\n//comment line 2' - res = _c.parseString(val, parseAll=True) + res = _c.parse_string(val, parseAll=True) self.assertEqual(list(res), ['comment line 1', 'comment line 2']) class TestNote(TestCase): def test_single_quote(self) -> None: val = "note: 'test note'" - res = note.parseString(val, parseAll=True) + res = note.parse_string(val, parseAll=True) self.assertEqual(res[0].text, 'test note') def test_double_quote(self) -> None: val = 'note: \n "test note"' - res = note.parseString(val, parseAll=True) + res = note.parse_string(val, parseAll=True) self.assertEqual(res[0].text, 'test note') def test_multiline(self) -> None: val = "note: '''line1\nline2\nline3'''" - res = note.parseString(val, parseAll=True) + res = note.parse_string(val, parseAll=True) self.assertEqual(res[0].text, 'line1\nline2\nline3') def test_unclosed_quote(self) -> None: val = 'note: "test note' with self.assertRaises(ParseSyntaxException): - note.parseString(val, parseAll=True) + note.parse_string(val, parseAll=True) def test_not_allowed_multiline(self) -> None: val = "note: 'line1\nline2\nline3'" with self.assertRaises(ParseSyntaxException): - note.parseString(val, parseAll=True) + note.parse_string(val, parseAll=True) class TestNoteObject(TestCase): def test_single_quote(self) -> None: val = "note {'test note'}" - res = note_object.parseString(val, parseAll=True) + res = note_object.parse_string(val, parseAll=True) self.assertEqual(res[0].text, 'test note') def test_double_quote(self) -> None: val = 'note \n\n {\n\n"test note"\n\n}' - res = note_object.parseString(val, parseAll=True) + res = note_object.parse_string(val, parseAll=True) self.assertEqual(res[0].text, 'test note') def test_multiline(self) -> None: val = "note\n{ '''line1\nline2\nline3'''}" - res = note_object.parseString(val, parseAll=True) + res = note_object.parse_string(val, parseAll=True) self.assertEqual(res[0].text, 'line1\nline2\nline3') def test_unclosed_quote(self) -> None: val = 'note{ "test note}' with self.assertRaises(ParseSyntaxException): - note_object.parseString(val, parseAll=True) + note_object.parse_string(val, parseAll=True) def test_not_allowed_multiline(self) -> None: val = "note { 'line1\nline2\nline3' }" with self.assertRaises(ParseSyntaxException): - note_object.parseString(val, parseAll=True) + note_object.parse_string(val, parseAll=True) diff --git a/test/test_enum.py b/test/test_definitions/test_enum.py similarity index 69% rename from test/test_enum.py rename to test/test_definitions/test_enum.py index ab2a38b..0d33d07 100644 --- a/test/test_enum.py +++ b/test/test_definitions/test_enum.py @@ -8,49 +8,49 @@ from pydbml.definitions.enum import enum_settings -ParserElement.setDefaultWhitespaceChars(' \t\r') +ParserElement.set_default_whitespace_chars(' \t\r') -class TestTableSettings(TestCase): +class TestEnumSettings(TestCase): def test_note(self) -> None: val = '[note: "note content"]' - enum_settings.parseString(val, parseAll=True) + enum_settings.parse_string(val, parseAll=True) def test_wrong(self) -> None: val = '[wrong]' with self.assertRaises(ParseSyntaxException): - enum_settings.parseString(val, parseAll=True) + enum_settings.parse_string(val, parseAll=True) class TestEnumItem(TestCase): def test_no_settings(self) -> None: val = 'student' - res = enum_item.parseString(val, parseAll=True) + res = enum_item.parse_string(val, parseAll=True) self.assertEqual(res[0].name, 'student') def test_settings(self) -> None: val = 'student [note: "our future, help us God"]' - res = enum_item.parseString(val, parseAll=True) + res = enum_item.parse_string(val, parseAll=True) self.assertEqual(res[0].name, 'student') self.assertEqual(res[0].note.text, 'our future, help us God') def test_comment_before(self) -> None: val = '//comment before\nstudent [note: "our future, help us God"]' - res = enum_item.parseString(val, parseAll=True) + res = enum_item.parse_string(val, parseAll=True) self.assertEqual(res[0].name, 'student') self.assertEqual(res[0].note.text, 'our future, help us God') self.assertEqual(res[0].comment, 'comment before') def test_comment_after(self) -> None: val = 'student [note: "our future, help us God"] //comment after' - res = enum_item.parseString(val, parseAll=True) + res = enum_item.parse_string(val, parseAll=True) self.assertEqual(res[0].name, 'student') self.assertEqual(res[0].note.text, 'our future, help us God') self.assertEqual(res[0].comment, 'comment after') def test_comment_both(self) -> None: val = '//comment before\nstudent [note: "our future, help us God"] //comment after' - res = enum_item.parseString(val, parseAll=True) + res = enum_item.parse_string(val, parseAll=True) self.assertEqual(res[0].name, 'student') self.assertEqual(res[0].note.text, 'our future, help us God') self.assertEqual(res[0].comment, 'comment after') @@ -59,19 +59,27 @@ def test_comment_both(self) -> None: class TestEnum(TestCase): def test_singe_item(self) -> None: val = 'enum members {\nstudent\n}' - res = enum.parseString(val, parseAll=True) + res = enum.parse_string(val, parseAll=True) self.assertEqual(len(res[0].items), 1) self.assertEqual(res[0].name, 'members') def test_several_items(self) -> None: val = 'enum members {janitor teacher\nstudent\nheadmaster\n}' - res = enum.parseString(val, parseAll=True) + res = enum.parse_string(val, parseAll=True) self.assertEqual(len(res[0].items), 4) self.assertEqual(res[0].name, 'members') + def test_schema(self) -> None: + val1 = 'enum members {janitor teacher\nstudent\nheadmaster\n}' + res1 = enum.parse_string(val1, parseAll=True) + self.assertEqual(res1[0].schema, 'public') + val2 = 'enum myschema.members {janitor teacher\nstudent\nheadmaster\n}' + res2 = enum.parse_string(val2, parseAll=True) + self.assertEqual(res2[0].schema, 'myschema') + def test_comment(self) -> None: val = '//comment before\nenum members {janitor teacher\nstudent\nheadmaster\n}' - res = enum.parseString(val, parseAll=True) + res = enum.parse_string(val, parseAll=True) self.assertEqual(len(res[0].items), 4) self.assertEqual(res[0].name, 'members') self.assertEqual(res[0].comment, 'comment before') @@ -79,4 +87,4 @@ def test_comment(self) -> None: def test_oneline(self) -> None: val = 'enum members {student}' with self.assertRaises(ParseSyntaxException): - enum.parseString(val, parseAll=True) + enum.parse_string(val, parseAll=True) diff --git a/test/test_definitions/test_generic.py b/test/test_definitions/test_generic.py new file mode 100644 index 0000000..9a78430 --- /dev/null +++ b/test/test_definitions/test_generic.py @@ -0,0 +1,24 @@ +from unittest import TestCase + +from pyparsing import ParserElement + +from pydbml.definitions.generic import expression_literal, expression +from pydbml.parser.blueprints import ExpressionBlueprint + + +ParserElement.set_default_whitespace_chars(' \t\r') + + +class TestExpressionLiteral(TestCase): + def test_expression_literal(self) -> None: + val = '`SUM(amount)`' + res = expression_literal.parse_string(val) + self.assertIsInstance(res[0], ExpressionBlueprint) + self.assertEqual(res[0].text, 'SUM(amount)') + +class TestExpression(TestCase): + def test_comma_separated_expression(self) -> None: + val = 'MAX, 3, "MAX", \'MAX\'' + expected = ['MAX', ',', '3', ',', '"MAX"', ',', "'MAX'"] + res = expression.parse_string(val, parseAll=True) + self.assertEqual(res.asList(), expected) diff --git a/test/test_index.py b/test/test_definitions/test_index.py similarity index 67% rename from test/test_index.py rename to test/test_definitions/test_index.py index f60dda9..b3239f4 100644 --- a/test/test_index.py +++ b/test/test_definitions/test_index.py @@ -12,85 +12,82 @@ from pydbml.definitions.index import indexes from pydbml.definitions.index import single_index_syntax from pydbml.definitions.index import subject +from pydbml.parser.blueprints import ExpressionBlueprint -ParserElement.setDefaultWhitespaceChars(' \t\r') +ParserElement.set_default_whitespace_chars(' \t\r') class TestIndexType(TestCase): def test_correct(self) -> None: - val = 'Type: BTREE' - res = index_type.parseString(val, parseAll=True) - self.assertEqual(res['type'], 'btree') - val2 = 'type:\nhash' - res2 = index_type.parseString(val2, parseAll=True) - self.assertEqual(res2['type'], 'hash') + for val, expected in [ + ("Type: BTREE", "btree"), + ("type: hash", "hash"), + ("type: gist", "gist"), + ("TYPE:SPGiST", "spgist"), + ("type: GIN", "gin"), + ("Type:\tbRiN", "brin"), + ]: + res = index_type.parse_string(val, parseAll=True) + self.assertEqual(res["type"], expected) def test_incorrect(self) -> None: val = 'type: wrong' with self.assertRaises(ParseSyntaxException): - index_type.parseString(val, parseAll=True) + index_type.parse_string(val, parseAll=True) class TestIndexSetting(TestCase): def test_unique(self) -> None: val = 'unique' - res = index_setting.parseString(val, parseAll=True) + res = index_setting.parse_string(val, parseAll=True) self.assertEqual(res['unique'], 'unique') def test_type(self) -> None: val = 'type: btree' - res = index_setting.parseString(val, parseAll=True) + res = index_setting.parse_string(val, parseAll=True) self.assertEqual(res['type'], 'btree') def test_name(self) -> None: val = 'name: "index name"' - res = index_setting.parseString(val, parseAll=True) + res = index_setting.parse_string(val, parseAll=True) self.assertEqual(res['name'], 'index name') def test_wrong_name(self) -> None: val = 'name: index name' with self.assertRaises(ParseSyntaxException): - index_setting.parseString(val, parseAll=True) + index_setting.parse_string(val, parseAll=True) val2 = 'name:,' with self.assertRaises(ParseSyntaxException): - index_setting.parseString(val2, parseAll=True) + index_setting.parse_string(val2, parseAll=True) def test_note(self) -> None: val = 'note: "note text"' - res = index_setting.parseString(val, parseAll=True) + res = index_setting.parse_string(val, parseAll=True) self.assertEqual(res['note'].text, 'note text') class TestIndexSettings(TestCase): def test_unique(self) -> None: val = '[unique]' - res = index_settings.parseString(val, parseAll=True) + res = index_settings.parse_string(val, parseAll=True) self.assertTrue(res[0]['unique']) def test_name_type_multiline(self) -> None: val = '[\nname: "index name"\n,\ntype:\nbtree\n]' - res = index_settings.parseString(val, parseAll=True) - self.assertEqual(res[0]['type_'], 'btree') + res = index_settings.parse_string(val, parseAll=True) + self.assertEqual(res[0]['type'], 'btree') self.assertEqual(res[0]['name'], 'index name') def test_pk(self) -> None: val = '[\npk\n]' - res = index_settings.parseString(val, parseAll=True) + res = index_settings.parse_string(val, parseAll=True) self.assertTrue(res[0]['pk']) - def test_wrong_pk(self) -> None: - val = '[pk, name: "not allowed"]' - with self.assertRaises(ParseSyntaxException): - index_settings.parseString(val, parseAll=True) - val2 = '[note: "pk not allowed", pk]' - with self.assertRaises(ParseSyntaxException): - index_settings.parseString(val2, parseAll=True) - def test_all(self) -> None: val = '[type: hash, name: "index name", note: "index note", unique]' - res = index_settings.parseString(val, parseAll=True) - self.assertEqual(res[0]['type_'], 'hash') + res = index_settings.parse_string(val, parseAll=True) + self.assertEqual(res[0]['type'], 'hash') self.assertEqual(res[0]['name'], 'index name') self.assertEqual(res[0]['note'].text, 'index note') self.assertTrue(res[0]['unique']) @@ -99,49 +96,50 @@ def test_all(self) -> None: class TestSubject(TestCase): def test_name(self) -> None: val = 'my_column' - res = subject.parseString(val, parseAll=True) + res = subject.parse_string(val, parseAll=True) self.assertEqual(res[0], val) def test_expression(self) -> None: val = '`id*3`' - res = subject.parseString(val, parseAll=True) - self.assertEqual(res[0], '(id*3)') + res = subject.parse_string(val, parseAll=True) + self.assertIsInstance(res[0], ExpressionBlueprint) + self.assertEqual(res[0].text, 'id*3') def test_wrong(self) -> None: val = '12d*(' with self.assertRaises(ParseException): - subject.parseString(val, parseAll=True) + subject.parse_string(val, parseAll=True) class TestSingleIndex(TestCase): def test_no_settings(self) -> None: val = 'my_column' - res = single_index_syntax.parseString(val, parseAll=True) + res = single_index_syntax.parse_string(val, parseAll=True) self.assertEqual(res['subject'], val) def test_settings(self) -> None: val = 'my_column [unique]' - res = single_index_syntax.parseString(val, parseAll=True) + res = single_index_syntax.parse_string(val, parseAll=True) self.assertEqual(res['subject'], 'my_column') self.assertTrue(res['settings']['unique']) def test_settings_on_new_line(self) -> None: val = 'my_column\n[unique]' with self.assertRaises(ParseException): - single_index_syntax.parseString(val, parseAll=True) + single_index_syntax.parse_string(val, parseAll=True) class TestCompositeIndex(TestCase): def test_no_settings(self) -> None: val = '(my_column, my_another_column)' - res = composite_index_syntax.parseString(val, parseAll=True) + res = composite_index_syntax.parse_string(val, parseAll=True) self.assertIn('my_column', list(res['subject'])) self.assertIn('my_another_column', list(res['subject'])) self.assertEqual(len(res['subject']), 2) def test_settings(self) -> None: val = '(my_column, my_another_column) [unique]' - res = composite_index_syntax.parseString(val, parseAll=True) + res = composite_index_syntax.parse_string(val, parseAll=True) self.assertIn('my_column', list(res['subject'])) self.assertIn('my_another_column', list(res['subject'])) self.assertEqual(len(res['subject']), 2) @@ -150,60 +148,63 @@ def test_settings(self) -> None: def test_new_line(self) -> None: val = '(my_column,\nmy_another_column) [unique]' with self.assertRaises(ParseException): - composite_index_syntax.parseString(val, parseAll=True) + composite_index_syntax.parse_string(val, parseAll=True) val2 = '(my_column, my_another_column)\n[unique]' with self.assertRaises(ParseException): - composite_index_syntax.parseString(val2, parseAll=True) + composite_index_syntax.parse_string(val2, parseAll=True) class TestIndex(TestCase): def test_single(self) -> None: val = 'my_column' - res = index.parseString(val, parseAll=True) + res = index.parse_string(val, parseAll=True) self.assertEqual(res[0].subject_names, ['my_column']) def test_expression(self) -> None: val = '(`id*3`)' - res = index.parseString(val, parseAll=True) - self.assertEqual(res[0].subject_names, ['(id*3)']) + res = index.parse_string(val, parseAll=True) + self.assertIsInstance(res[0].subject_names[0], ExpressionBlueprint) + self.assertEqual(res[0].subject_names[0].text, 'id*3') def test_composite(self) -> None: val = '(my_column, my_another_column)' - res = index.parseString(val, parseAll=True) + res = index.parse_string(val, parseAll=True) self.assertEqual(res[0].subject_names, ['my_column', 'my_another_column']) def test_composite_with_expression(self) -> None: val = '(`id*3`, fieldname)' - res = index.parseString(val, parseAll=True) - self.assertEqual(res[0].subject_names, ['(id*3)', 'fieldname']) + res = index.parse_string(val, parseAll=True) + self.assertIsInstance(res[0].subject_names[0], ExpressionBlueprint) + self.assertEqual(res[0].subject_names[0].text, 'id*3') + self.assertEqual(res[0].subject_names[1], 'fieldname') def test_with_settings(self) -> None: val = '(my_column, my_another_column) [unique]' - res = index.parseString(val, parseAll=True) + res = index.parse_string(val, parseAll=True) self.assertEqual(res[0].subject_names, ['my_column', 'my_another_column']) self.assertTrue(res[0].unique) def test_comment_above(self) -> None: val = '//comment above\nmy_column [unique]' - res = index.parseString(val, parseAll=True) + res = index.parse_string(val, parseAll=True) self.assertEqual(res[0].subject_names, ['my_column']) self.assertTrue(res[0].unique) self.assertEqual(res[0].comment, 'comment above') def test_comment_after(self) -> None: val = 'my_column [unique] //comment after' - res = index.parseString(val, parseAll=True) + res = index.parse_string(val, parseAll=True) self.assertEqual(res[0].subject_names, ['my_column']) self.assertTrue(res[0].unique) self.assertEqual(res[0].comment, 'comment after') val = 'my_column //comment after' - res = index.parseString(val, parseAll=True) + res = index.parse_string(val, parseAll=True) self.assertEqual(res[0].subject_names, ['my_column']) self.assertEqual(res[0].comment, 'comment after') def test_both_comments(self) -> None: val = '//comment before\nmy_column [unique] //comment after' - res = index.parseString(val, parseAll=True) + res = index.parse_string(val, parseAll=True) self.assertEqual(res[0].subject_names, ['my_column']) self.assertTrue(res[0].unique) self.assertEqual(res[0].comment, 'comment after') @@ -224,10 +225,10 @@ def test_valid(self) -> None: (`id*3`,`getdate()`) (`id*3`,id) }''' - res = indexes.parseString(val) + res = indexes.parse_string(val) self.assertEqual(len(res), 8) def test_invalid(self) -> None: val = 'indexes {my_column' with self.assertRaises(ParseSyntaxException): - indexes.parseString(val) + indexes.parse_string(val) diff --git a/test/test_project.py b/test/test_definitions/test_project.py similarity index 80% rename from test/test_project.py rename to test/test_definitions/test_project.py index 650965a..c049a0a 100644 --- a/test/test_project.py +++ b/test/test_definitions/test_project.py @@ -7,36 +7,36 @@ from pydbml.definitions.project import project_field -ParserElement.setDefaultWhitespaceChars(' \t\r') +ParserElement.set_default_whitespace_chars(' \t\r') class TestProjectField(TestCase): def test_ok(self) -> None: val = "field: 'value'" - project_field.parseString(val, parseAll=True) + project_field.parse_string(val, parseAll=True) def test_nok(self) -> None: val = "field: value" with self.assertRaises(ParseSyntaxException): - project_field.parseString(val, parseAll=True) + project_field.parse_string(val, parseAll=True) class TestProject(TestCase): def test_empty(self) -> None: val = 'project name {}' - res = project.parseString(val, parseAll=True) + res = project.parse_string(val, parseAll=True) self.assertEqual(res[0].name, 'name') def test_fields(self) -> None: val = "project name {field1: 'value1' field2: 'value2'}" - res = project.parseString(val, parseAll=True) + res = project.parse_string(val, parseAll=True) self.assertEqual(res[0].name, 'name') self.assertEqual(res[0].items['field1'], 'value1') self.assertEqual(res[0].items['field2'], 'value2') def test_fields_and_note(self) -> None: val = "project name {\nfield1: 'value1'\nfield2: 'value2'\nnote: 'note value'}" - res = project.parseString(val, parseAll=True) + res = project.parse_string(val, parseAll=True) self.assertEqual(res[0].name, 'name') self.assertEqual(res[0].items['field1'], 'value1') self.assertEqual(res[0].items['field2'], 'value2') @@ -44,7 +44,7 @@ def test_fields_and_note(self) -> None: def test_comment(self) -> None: val = "//comment before\nproject name {\nfield1: 'value1'\nfield2: 'value2'\nnote: 'note value'}" - res = project.parseString(val, parseAll=True) + res = project.parse_string(val, parseAll=True) self.assertEqual(res[0].name, 'name') self.assertEqual(res[0].items['field1'], 'value1') self.assertEqual(res[0].items['field2'], 'value2') diff --git a/test/test_reference.py b/test/test_definitions/test_reference.py similarity index 78% rename from test/test_reference.py rename to test/test_definitions/test_reference.py index a1a5dc9..cbcc96c 100644 --- a/test/test_reference.py +++ b/test/test_definitions/test_reference.py @@ -12,31 +12,39 @@ from pydbml.definitions.reference import relation -ParserElement.setDefaultWhitespaceChars(' \t\r') +ParserElement.set_default_whitespace_chars(' \t\r') class TestRelation(TestCase): def test_ok(self) -> None: vals = ['>', '-', '<'] for v in vals: - relation.parseString(v, parseAll=True) + relation.parse_string(v, parseAll=True) def test_nok(self) -> None: val = 'wrong' with self.assertRaises(ParseException): - relation.parseString(val, parseAll=True) + relation.parse_string(val, parseAll=True) class TestInlineRelation(TestCase): def test_ok(self) -> None: val = 'ref: < table.column' - res = ref_inline.parseString(val, parseAll=True) + res = ref_inline.parse_string(val, parseAll=True) self.assertEqual(res[0].type, '<') self.assertEqual(res[0].table2, 'table') self.assertEqual(res[0].col2, 'column') self.assertIsNone(res[0].table1) self.assertIsNone(res[0].col1) + def test_schema(self) -> None: + val1 = 'ref: < table.column' + res1 = ref_inline.parse_string(val1, parseAll=True) + self.assertEqual(res1[0].schema2, 'public') + val2 = 'ref: < myschema.table.column' + res2 = ref_inline.parse_string(val2, parseAll=True) + self.assertEqual(res2[0].schema2, 'myschema') + def test_nok(self) -> None: vals = [ 'ref:\n< table.column', @@ -46,7 +54,7 @@ def test_nok(self) -> None: ] for v in vals: with self.assertRaises(ParseSyntaxException): - ref_inline.parseString(v) + ref_inline.parse_string(v) class TestOnOption(TestCase): @@ -59,23 +67,23 @@ def test_ok(self) -> None: 'set default' ] for v in vals: - on_option.parseString(v, parseAll=True) + on_option.parse_string(v, parseAll=True) def test_nok(self) -> None: val = 'wrong' with self.assertRaises(ParseException): - on_option.parseString(val, parseAll=True) + on_option.parse_string(val, parseAll=True) class TestRefSettings(TestCase): def test_one_setting(self) -> None: val = '[delete: cascade]' - res = ref_settings.parseString(val, parseAll=True) + res = ref_settings.parse_string(val, parseAll=True) self.assertEqual(res[0]['on_delete'], 'cascade') def test_two_settings_multiline(self) -> None: val = '[\ndelete:\ncascade\n,\nupdate:\nrestrict\n]' - res = ref_settings.parseString(val, parseAll=True) + res = ref_settings.parse_string(val, parseAll=True) self.assertEqual(res[0]['on_delete'], 'cascade') self.assertEqual(res[0]['on_update'], 'restrict') @@ -83,16 +91,26 @@ def test_two_settings_multiline(self) -> None: class TestRefShort(TestCase): def test_no_name(self) -> None: val = 'ref: table1.col1 > table2.col2' - res = ref_short.parseString(val, parseAll=True) + res = ref_short.parse_string(val, parseAll=True) self.assertEqual(res[0].type, '>') self.assertEqual(res[0].table1, 'table1') self.assertEqual(res[0].col1, 'col1') self.assertEqual(res[0].table2, 'table2') self.assertEqual(res[0].col2, 'col2') + def test_schema(self) -> None: + val1 = 'ref: table1.col1 > table2.col2' + res1 = ref_short.parse_string(val1, parseAll=True) + self.assertEqual(res1[0].schema1, 'public') + self.assertEqual(res1[0].schema2, 'public') + val2 = 'ref: myschema1.table1.col1 > myschema2.table2.col2' + res2 = ref_short.parse_string(val2, parseAll=True) + self.assertEqual(res2[0].schema1, 'myschema1') + self.assertEqual(res2[0].schema2, 'myschema2') + def test_name(self) -> None: val = 'ref name: table1.col1 > table2.col2' - res = ref_short.parseString(val, parseAll=True) + res = ref_short.parse_string(val, parseAll=True) self.assertEqual(res[0].type, '>') self.assertEqual(res[0].table1, 'table1') self.assertEqual(res[0].col1, 'col1') @@ -102,7 +120,7 @@ def test_name(self) -> None: def test_composite_with_name(self) -> None: val = 'ref name: table1.(col1 , col2,col3) > table2.(col11 , col21,col31)' - res = ref_short.parseString(val, parseAll=True) + res = ref_short.parse_string(val, parseAll=True) self.assertEqual(res[0].type, '>') self.assertEqual(res[0].table1, 'table1') self.assertEqual(res[0].col1, '(col1 , col2,col3)') @@ -112,7 +130,7 @@ def test_composite_with_name(self) -> None: def test_with_settings(self) -> None: val = 'ref name: table1.col1 > table2.col2 [update: cascade, delete: restrict]' - res = ref_short.parseString(val, parseAll=True) + res = ref_short.parse_string(val, parseAll=True) self.assertEqual(res[0].type, '>') self.assertEqual(res[0].table1, 'table1') self.assertEqual(res[0].col1, 'col1') @@ -125,14 +143,14 @@ def test_with_settings(self) -> None: def test_newline(self) -> None: val = 'ref\nname: table1.col1 > table2.col2' with self.assertRaises(ParseException): - ref_short.parseString(val, parseAll=True) + ref_short.parse_string(val, parseAll=True) val2 = 'ref name: table1.col1\n> table2.col2' with self.assertRaises(ParseSyntaxException): - ref_short.parseString(val2, parseAll=True) + ref_short.parse_string(val2, parseAll=True) def test_comment_above(self) -> None: val = '//comment above\nref name: table1.col1 > table2.col2' - res = ref_short.parseString(val, parseAll=True) + res = ref_short.parse_string(val, parseAll=True) self.assertEqual(res[0].type, '>') self.assertEqual(res[0].table1, 'table1') self.assertEqual(res[0].col1, 'col1') @@ -143,7 +161,7 @@ def test_comment_above(self) -> None: def test_comment_after(self) -> None: val = 'ref name: table1.col1 > table2.col2 //comment after' - res = ref_short.parseString(val, parseAll=True) + res = ref_short.parse_string(val, parseAll=True) self.assertEqual(res[0].type, '>') self.assertEqual(res[0].table1, 'table1') self.assertEqual(res[0].col1, 'col1') @@ -152,7 +170,7 @@ def test_comment_after(self) -> None: self.assertEqual(res[0].name, 'name') self.assertEqual(res[0].comment, 'comment after') val2 = 'ref name: table1.col1 > table2.col2 [update: cascade, delete: restrict] //comment after' - res2 = ref_short.parseString(val2, parseAll=True) + res2 = ref_short.parse_string(val2, parseAll=True) self.assertEqual(res2[0].type, '>') self.assertEqual(res2[0].table1, 'table1') self.assertEqual(res2[0].col1, 'col1') @@ -165,7 +183,7 @@ def test_comment_after(self) -> None: def test_comment_both(self) -> None: val = '//comment above\nref name: table1.col1 > table2.col2 //comment after' - res = ref_short.parseString(val, parseAll=True) + res = ref_short.parse_string(val, parseAll=True) self.assertEqual(res[0].type, '>') self.assertEqual(res[0].table1, 'table1') self.assertEqual(res[0].col1, 'col1') @@ -174,7 +192,7 @@ def test_comment_both(self) -> None: self.assertEqual(res[0].name, 'name') self.assertEqual(res[0].comment, 'comment after') val2 = '//comment above\nref name: table1.col1 > table2.col2 [update: cascade, delete: restrict] //comment after' - res2 = ref_short.parseString(val2, parseAll=True) + res2 = ref_short.parse_string(val2, parseAll=True) self.assertEqual(res2[0].type, '>') self.assertEqual(res2[0].table1, 'table1') self.assertEqual(res2[0].col1, 'col1') @@ -189,16 +207,26 @@ def test_comment_both(self) -> None: class TestRefLong(TestCase): def test_no_name(self) -> None: val = 'ref {table1.col1 > table2.col2}' - res = ref_long.parseString(val, parseAll=True) + res = ref_long.parse_string(val, parseAll=True) self.assertEqual(res[0].type, '>') self.assertEqual(res[0].table1, 'table1') self.assertEqual(res[0].col1, 'col1') self.assertEqual(res[0].table2, 'table2') self.assertEqual(res[0].col2, 'col2') + def test_schema(self) -> None: + val1 = 'ref {table1.col1 > table2.col2}' + res1 = ref_long.parse_string(val1, parseAll=True) + self.assertEqual(res1[0].schema1, 'public') + self.assertEqual(res1[0].schema2, 'public') + val2 = 'ref {myschema1.table1.col1 > myschema2.table2.col2}' + res2 = ref_long.parse_string(val2, parseAll=True) + self.assertEqual(res2[0].schema1, 'myschema1') + self.assertEqual(res2[0].schema2, 'myschema2') + def test_name(self) -> None: val = 'ref\nname\n{\ntable1.col1 > table2.col2\n}' - res = ref_long.parseString(val, parseAll=True) + res = ref_long.parse_string(val, parseAll=True) self.assertEqual(res[0].type, '>') self.assertEqual(res[0].table1, 'table1') self.assertEqual(res[0].col1, 'col1') @@ -208,7 +236,7 @@ def test_name(self) -> None: def test_with_settings(self) -> None: val = 'ref name {\ntable1.col1 > table2.col2 [update: cascade, delete: restrict]\n}' - res = ref_long.parseString(val, parseAll=True) + res = ref_long.parse_string(val, parseAll=True) self.assertEqual(res[0].type, '>') self.assertEqual(res[0].table1, 'table1') self.assertEqual(res[0].col1, 'col1') @@ -220,7 +248,7 @@ def test_with_settings(self) -> None: def test_comment_above(self) -> None: val = '//comment above\nref name {\ntable1.col1 > table2.col2\n}' - res = ref_long.parseString(val, parseAll=True) + res = ref_long.parse_string(val, parseAll=True) self.assertEqual(res[0].type, '>') self.assertEqual(res[0].table1, 'table1') self.assertEqual(res[0].col1, 'col1') @@ -231,7 +259,7 @@ def test_comment_above(self) -> None: def test_comment_after(self) -> None: val = 'ref name {\ntable1.col1 > table2.col2 //comment after\n}' - res = ref_long.parseString(val, parseAll=True) + res = ref_long.parse_string(val, parseAll=True) self.assertEqual(res[0].type, '>') self.assertEqual(res[0].table1, 'table1') self.assertEqual(res[0].col1, 'col1') @@ -240,7 +268,7 @@ def test_comment_after(self) -> None: self.assertEqual(res[0].name, 'name') self.assertEqual(res[0].comment, 'comment after') val2 = 'ref name {\ntable1.col1 > table2.col2 [update: cascade, delete: restrict] //comment after\n}' - res2 = ref_long.parseString(val2, parseAll=True) + res2 = ref_long.parse_string(val2, parseAll=True) self.assertEqual(res2[0].type, '>') self.assertEqual(res2[0].table1, 'table1') self.assertEqual(res2[0].col1, 'col1') @@ -253,7 +281,7 @@ def test_comment_after(self) -> None: def test_comment_both(self) -> None: val = '//comment above\nref name {\ntable1.col1 > table2.col2 //comment after\n}' - res = ref_long.parseString(val, parseAll=True) + res = ref_long.parse_string(val, parseAll=True) self.assertEqual(res[0].type, '>') self.assertEqual(res[0].table1, 'table1') self.assertEqual(res[0].col1, 'col1') @@ -262,7 +290,7 @@ def test_comment_both(self) -> None: self.assertEqual(res[0].name, 'name') self.assertEqual(res[0].comment, 'comment after') val2 = '//comment above\nref name {\ntable1.col1 > table2.col2 [update: cascade, delete: restrict] //comment after\n}' - res2 = ref_long.parseString(val2, parseAll=True) + res2 = ref_long.parse_string(val2, parseAll=True) self.assertEqual(res2[0].type, '>') self.assertEqual(res2[0].table1, 'table1') self.assertEqual(res2[0].col1, 'col1') diff --git a/test/test_definitions/test_sticky_note.py b/test/test_definitions/test_sticky_note.py new file mode 100644 index 0000000..2870b9d --- /dev/null +++ b/test/test_definitions/test_sticky_note.py @@ -0,0 +1,35 @@ +from unittest import TestCase + +from pyparsing import ParseSyntaxException + +from pydbml.definitions.sticky_note import sticky_note + + +class TestSticky(TestCase): + def test_single_quote(self) -> None: + val = "note mynote {'test note'}" + res = sticky_note.parse_string(val, parseAll=True) + self.assertEqual(res[0].name, 'mynote') + self.assertEqual(res[0].text, 'test note') + + def test_double_quote(self) -> None: + val = 'note \n\nmynote\n\n {\n\n"test note"\n\n}' + res = sticky_note.parse_string(val, parseAll=True) + self.assertEqual(res[0].name, 'mynote') + self.assertEqual(res[0].text, 'test note') + + def test_multiline(self) -> None: + val = "note\nmynote\n{ '''line1\nline2\nline3'''}" + res = sticky_note.parse_string(val, parseAll=True) + self.assertEqual(res[0].name, 'mynote') + self.assertEqual(res[0].text, 'line1\nline2\nline3') + + def test_unclosed_quote(self) -> None: + val = 'note mynote{ "test note}' + with self.assertRaises(ParseSyntaxException): + sticky_note.parse_string(val, parseAll=True) + + def test_not_allowed_multiline(self) -> None: + val = "note mynote { 'line1\nline2\nline3' }" + with self.assertRaises(ParseSyntaxException): + sticky_note.parse_string(val, parseAll=True) diff --git a/test/test_definitions/test_table.py b/test/test_definitions/test_table.py new file mode 100644 index 0000000..cb65ef6 --- /dev/null +++ b/test/test_definitions/test_table.py @@ -0,0 +1,228 @@ +from unittest import TestCase + +from pyparsing import ParseException +from pyparsing import ParseSyntaxException +from pyparsing import ParserElement + +from pydbml.definitions.table import alias, table_with_properties +from pydbml.definitions.table import header_color +from pydbml.definitions.table import table +from pydbml.definitions.table import table_body +from pydbml.definitions.table import table_settings + + +ParserElement.set_default_whitespace_chars(" \t\r") + + +class TestAlias(TestCase): + def test_ok(self) -> None: + val = "as Alias" + alias.parse_string(val, parseAll=True) + + def test_nok(self) -> None: + val = "asalias" + with self.assertRaises(ParseSyntaxException): + alias.parse_string(val, parseAll=True) + + +class TestHeaderColor(TestCase): + def test_oneline(self) -> None: + val = "headercolor: #CCCCCC" + res = header_color.parse_string(val, parseAll=True) + self.assertEqual(res["header_color"], "#CCCCCC") + + def test_multiline(self) -> None: + val = "headercolor:\n\n#E02" + res = header_color.parse_string(val, parseAll=True) + self.assertEqual(res["header_color"], "#E02") + + +class TestTableSettings(TestCase): + def test_one(self) -> None: + val = "[headercolor: #E024DF]" + res = table_settings.parse_string(val, parseAll=True) + self.assertEqual(res[0]["header_color"], "#E024DF") + + def test_both(self) -> None: + val = '[note: "note content", headercolor: #E024DF]' + res = table_settings.parse_string(val, parseAll=True) + self.assertEqual(res[0]["header_color"], "#E024DF") + self.assertIn("note", res[0]) + + +class TestTableBody(TestCase): + def test_one_column(self) -> None: + val = "id integer [pk, increment]\n" + res = table_body.parse_string(val, parseAll=True) + self.assertEqual(len(res["columns"]), 1) + + def test_two_columns(self) -> None: + val = "id integer [pk, increment]\nname string\n" + res = table_body.parse_string(val, parseAll=True) + self.assertEqual(len(res["columns"]), 2) + + def test_columns_indexes(self) -> None: + val = """ +id integer +country varchar [NOT NULL, ref: > countries.country_name] +booking_date date unique pk +indexes { + (id, country) [pk] // composite primary key +}""" + res = table_body.parse_string(val, parseAll=True) + self.assertEqual(len(res["columns"]), 3) + self.assertEqual(len(res["indexes"]), 1) + + def test_columns_indexes_note(self) -> None: + val = """ +id integer +country varchar [NOT NULL, ref: > countries.country_name] +booking_date date unique pk +note: 'mynote' +indexes { + (id, country) [pk] // composite primary key +}""" + res = table_body.parse_string(val, parseAll=True) + self.assertEqual(len(res["columns"]), 3) + self.assertEqual(len(res["indexes"]), 1) + self.assertIsNotNone(res["note"]) + val2 = """ +id integer +country varchar [NOT NULL, ref: > countries.country_name] +booking_date date unique pk +note { + 'mynote' +} +indexes { + (id, country) [pk] // composite primary key +}""" + res2 = table_body.parse_string(val2, parseAll=True) + self.assertEqual(len(res2["columns"]), 3) + self.assertEqual(len(res2["indexes"]), 1) + self.assertIsNotNone(res2["note"]) + + def test_columns_after_indexes_are_allowed(self) -> None: + val = """ +note: 'mynote' +indexes { + (id, country) [pk] // composite primary key +} +id integer""" + res = table_body.parse_string(val, parseAll=True) + self.assertEqual(len(res["columns"]), 1) + self.assertEqual(len(res["indexes"]), 1) + + +class TestTable(TestCase): + def test_simple(self) -> None: + val = "table ids {\nid integer\n}" + res = table.parse_string(val, parseAll=True) + self.assertEqual(res[0].name, "ids") + self.assertEqual(len(res[0].columns), 1) + + def test_no_columns(self) -> None: + val = "table ids {\nNote: 'No columns!'\n}" + with self.assertRaises(SyntaxError): + res = table.parse_string(val, parseAll=True) + + def test_with_alias(self) -> None: + val = "table ids as ii {\nid integer\n}" + res = table.parse_string(val, parseAll=True) + self.assertEqual(res[0].name, "ids") + self.assertEqual(res[0].alias, "ii") + self.assertEqual(len(res[0].columns), 1) + + def test_schema(self) -> None: + val = "table ids as ii {\nid integer\n}" + res = table.parse_string(val, parseAll=True) + self.assertEqual(res[0].name, "ids") + self.assertEqual(res[0].schema, "public") # default + self.assertEqual(len(res[0].columns), 1) + + val = "table myschema.ids as ii {\nid integer\n}" + res = table.parse_string(val, parseAll=True) + self.assertEqual(res[0].name, "ids") + self.assertEqual(res[0].schema, "myschema") + + def test_with_settings(self) -> None: + val = 'table ids as ii [headercolor: #ccc, note: "headernote"] {\nid integer\n}' + res = table.parse_string(val, parseAll=True) + self.assertEqual(res[0].name, "ids") + self.assertEqual(res[0].alias, "ii") + self.assertEqual(res[0].header_color, "#ccc") + self.assertEqual(res[0].note.text, "headernote") + self.assertEqual(len(res[0].columns), 1) + + def test_with_body_note(self) -> None: + val = """ +table ids as ii [ + headercolor: #ccc, + note: "headernote"] +{ + id integer + note: "bodynote" +}""" + res = table.parse_string(val, parseAll=True) + self.assertEqual(res[0].name, "ids") + self.assertEqual(res[0].alias, "ii") + self.assertEqual(res[0].header_color, "#ccc") + self.assertEqual(res[0].note.text, "bodynote") + self.assertEqual(len(res[0].columns), 1) + + def test_comment_after(self) -> None: + val = """ +// some comment before table +table ids as ii [ + headercolor: #ccc, + note: "headernote"] +{ + id integer + note: "bodynote" +} // some somment after table""" + res = table.parse_string(val, parseAll=True) + self.assertEqual(res[0].comment, "some comment before table") + self.assertEqual(res[0].name, "ids") + self.assertEqual(res[0].alias, "ii") + self.assertEqual(res[0].header_color, "#ccc") + self.assertEqual(res[0].note.text, "bodynote") + self.assertEqual(len(res[0].columns), 1) + + def test_with_indexes(self) -> None: + val = """ +table ids as ii [ + headercolor: #ccc, + note: "headernote"] +{ + id integer + country varchar + note: "bodynote" + indexes { + (id, country) [pk] // composite primary key + } +}""" + res = table.parse_string(val, parseAll=True) + self.assertEqual(res[0].name, "ids") + self.assertEqual(res[0].alias, "ii") + self.assertEqual(res[0].header_color, "#ccc") + self.assertEqual(res[0].note.text, "bodynote") + self.assertEqual(len(res[0].columns), 2) + self.assertEqual(len(res[0].indexes), 1) + + +def test_properties() -> None: + val = """ +table ids as ii [ + headercolor: #ccc, + note: "headernote"] +{ + id integer + country varchar + note: "bodynote" + foo: 'bar' + baz: '''123''' + indexes { + (id, country) [pk] // composite primary key + } +}""" + res = table_with_properties.parse_string(val, parseAll=True) + assert res[0].properties == {"foo": "bar", "baz": "123"} diff --git a/test/test_definitions/test_table_group.py b/test/test_definitions/test_table_group.py new file mode 100644 index 0000000..b119d1d --- /dev/null +++ b/test/test_definitions/test_table_group.py @@ -0,0 +1,74 @@ +from textwrap import dedent +from unittest import TestCase + +from pyparsing import ParserElement + +from pydbml.definitions.table_group import table_group + + +ParserElement.set_default_whitespace_chars(" \t\r") + + +class TestTableGroup(TestCase): + def test_empty(self) -> None: + val = "TableGroup name {}" + res = table_group.parse_string(val, parseAll=True) + self.assertEqual(res[0].name, "name") + + def test_fields(self) -> None: + val = "TableGroup name {table1 table2}" + res = table_group.parse_string(val, parseAll=True) + self.assertEqual(res[0].name, "name") + self.assertEqual(res[0].items, ["table1", "table2"]) + + def test_comment(self) -> None: + val = "//comment before\nTableGroup name\n{\ntable1\ntable2\n}" + res = table_group.parse_string(val, parseAll=True) + self.assertEqual(res[0].name, "name") + self.assertEqual(res[0].items, ["table1", "table2"]) + self.assertEqual(res[0].comment, "comment before") + + def test_note_settings(self) -> None: + val = "TableGroup name [note: 'My note'] \n{\ntable1\ntable2\n}" + res = table_group.parse_string(val, parseAll=True) + self.assertEqual(res[0].name, "name") + self.assertEqual(res[0].items, ["table1", "table2"]) + self.assertEqual(res[0].note.text, "My note") + + def test_color(self) -> None: + val = "TableGroup name [color: #FFF] \n{\ntable1\ntable2\n}" + res = table_group.parse_string(val, parseAll=True) + self.assertEqual(res[0].name, "name") + self.assertEqual(res[0].items, ["table1", "table2"]) + self.assertEqual(res[0].color, "#FFF") + + def test_all_settings(self) -> None: + val = "TableGroup name [color: #FFF, note: 'My note'] \n{\ntable1\ntable2\n}" + res = table_group.parse_string(val, parseAll=True) + self.assertEqual(res[0].name, "name") + self.assertEqual(res[0].items, ["table1", "table2"]) + self.assertEqual(res[0].color, "#FFF") + self.assertEqual(res[0].note.text, "My note") + + def test_note_body(self) -> None: + val = dedent("""\ + TableGroup name { + table1 + Note: ''' + Note line1 + Note line2 + ''' + table2 + } + """) + res = table_group.parse_string(val, parseAll=True) + self.assertEqual(res[0].name, "name") + self.assertEqual(res[0].items, ["table1", "table2"]) + self.assertIn("Note line1\n", res[0].note.text,) + + def test_note_settings_overriden_by_note_body(self) -> None: + val = "TableGroup name [note: 'Settings note'] \n{\ntable1\nnote: 'Body note'\ntable2\n}" + res = table_group.parse_string(val, parseAll=True) + self.assertEqual(res[0].name, "name") + self.assertEqual(res[0].items, ["table1", "table2"]) + self.assertEqual(res[0].note.text, "Body note") diff --git a/test/test_docs.py b/test/test_docs.py index f3fae25..c5025bd 100644 --- a/test/test_docs.py +++ b/test/test_docs.py @@ -9,11 +9,10 @@ from unittest import TestCase from pydbml import PyDBML -from pydbml.exceptions import ColumnNotFoundError -from pydbml.exceptions import TableNotFoundError - +from pydbml.classes import Expression, Enum TEST_DOCS_PATH = Path(os.path.abspath(__file__)).parent / 'test_data/docs' +TestCase.maxDiff = None class TestDocs(TestCase): @@ -25,7 +24,7 @@ def test_example(self) -> None: self.assertEqual(len(results.refs), 1) ref = results.refs[0] - self.assertEqual((posts, users), (ref.table1, ref.table2)) + self.assertEqual((posts, users), (ref.col1[0].table, ref.col2[0].table)) def test_project(self) -> None: results = PyDBML.parse_file(TEST_DOCS_PATH / 'project.dbml') @@ -52,7 +51,7 @@ def test_table_alias(self) -> None: self.assertEqual(len(results.refs), 1) ref = results.refs[0] - self.assertEqual((u, posts), (ref.table1, ref.table2)) + self.assertEqual((u, posts), (ref.col1[0].table, ref.col2[0].table)) def test_table_notes(self) -> None: results = PyDBML.parse_file(TEST_DOCS_PATH / 'table_notes.dbml') @@ -94,7 +93,8 @@ def test_default_value(self) -> None: self.assertEqual([c.name for c in table.columns], ['id', 'username', 'full_name', 'gender', 'created_at', 'rating']) *_, gender, created_at, rating = table.columns self.assertEqual(gender.default, 'm') - self.assertEqual(created_at.default, '(now())') + self.assertIsInstance(created_at.default, Expression) + self.assertEqual(created_at.default.text, 'now()') self.assertEqual(rating.default, 10) def test_index_definition(self) -> None: @@ -123,11 +123,17 @@ def test_index_definition(self) -> None: self.assertEqual(ix[4].subjects, [table['booking_date']]) self.assertEqual(ix[4].type, 'hash') - self.assertEqual(ix[5].subjects, ['(id*2)']) + self.assertEqual(len(ix[5].subjects), 1) + self.assertIsInstance(ix[5].subjects[0], Expression) + self.assertEqual(ix[5].subjects[0].text, 'id*2') - self.assertEqual(ix[6].subjects, ['(id*3)', '(getdate())']) + self.assertEqual(len(ix[6].subjects), 2) + self.assertIsInstance(ix[6].subjects[0], Expression) + self.assertIsInstance(ix[6].subjects[1], Expression) + self.assertEqual(ix[6].subjects[0].text, 'id*3') + self.assertEqual(ix[6].subjects[1].text, 'getdate()') - self.assertEqual(ix[7].subjects, ['(id*3)', table['id']]) + self.assertEqual(ix[7].subjects, [Expression('id*3'), table['id']]) def test_relationships(self) -> None: results = PyDBML.parse_file(TEST_DOCS_PATH / 'relationships_1.dbml') @@ -135,12 +141,12 @@ def test_relationships(self) -> None: rf = results.refs - self.assertEqual(rf[0].table1, posts) - self.assertEqual(rf[0].table2, users) + self.assertEqual(rf[0].col1[0].table, posts) + self.assertEqual(rf[0].col2[0].table, users) self.assertEqual(rf[0].type, '>') - self.assertEqual(rf[1].table1, reviews) - self.assertEqual(rf[1].table2, users) + self.assertEqual(rf[1].col1[0].table, reviews) + self.assertEqual(rf[1].col2[0].table, users) self.assertEqual(rf[1].type, '>') results = PyDBML.parse_file(TEST_DOCS_PATH / 'relationships_2.dbml') @@ -148,12 +154,12 @@ def test_relationships(self) -> None: rf = results.refs - self.assertEqual(rf[0].table1, users) - self.assertEqual(rf[0].table2, posts) + self.assertEqual(rf[0].col1[0].table, users) + self.assertEqual(rf[0].col2[0].table, posts) self.assertEqual(rf[0].type, '<') - self.assertEqual(rf[1].table1, users) - self.assertEqual(rf[1].table2, reviews) + self.assertEqual(rf[1].col1[0].table, users) + self.assertEqual(rf[1].col2[0].table, reviews) self.assertEqual(rf[1].type, '<') def test_relationships_composite(self) -> None: @@ -164,8 +170,8 @@ def test_relationships_composite(self) -> None: self.assertEqual(len(rf), 1) - self.assertEqual(rf[0].table1, merchant_periods) - self.assertEqual(rf[0].table2, merchants) + self.assertEqual(rf[0].col1[0].table, merchant_periods) + self.assertEqual(rf[0].col2[0].table, merchants) self.assertEqual(rf[0].type, '>') self.assertEqual( rf[0].col1, @@ -190,8 +196,8 @@ def test_relationship_settings(self) -> None: self.assertEqual(len(rf), 1) - self.assertEqual(rf[0].table1, merchant_periods) - self.assertEqual(rf[0].table2, merchants) + self.assertEqual(rf[0].col1[0].table, merchant_periods) + self.assertEqual(rf[0].col2[0].table, merchants) self.assertEqual(rf[0].type, '>') self.assertEqual(rf[0].on_delete, 'cascade') self.assertEqual(rf[0].on_update, 'no action') @@ -199,7 +205,7 @@ def test_relationship_settings(self) -> None: def test_note_definition(self) -> None: results = PyDBML.parse_file(TEST_DOCS_PATH / 'note_definition.dbml') self.assertEqual(len(results.tables), 1) - users = results['users'] + users = results['public.users'] self.assertEqual(users.note.text, 'This is a note of this table') def test_project_notes(self) -> None: @@ -207,11 +213,11 @@ def test_project_notes(self) -> None: project = results.project self.assertEqual(project.name, 'DBML') - self.assertTrue(project.note.text.startswith('\n # DBML - Database Markup Language\n DBML')) + self.assertTrue(project.note.text.startswith('# DBML - Database Markup Language\nDBML')) def test_column_notes(self) -> None: results = PyDBML.parse_file(TEST_DOCS_PATH / 'column_notes.dbml') - users = results['users'] + users = results['public.users'] self.assertEqual(users.note.text, 'Stores user data') self.assertEqual(users['column_name'].note.text, 'replace text here') @@ -219,15 +225,19 @@ def test_column_notes(self) -> None: def test_enum_definition(self) -> None: results = PyDBML.parse_file(TEST_DOCS_PATH / 'enum_definition.dbml') - jobs = results['jobs'] - jobs['status'].type == 'job_status' + jobs = results['public.jobs'] + self.assertIsInstance(jobs['status'].type, Enum) + self.assertIsInstance(jobs['grade'].type, Enum) - self.assertEqual(len(results.enums), 1) - js = results.enums[0] + self.assertEqual(len(results.enums), 2) + js, g = results.enums self.assertEqual(js.name, 'job_status') self.assertEqual([ei.name for ei in js.items], ['created', 'running', 'done', 'failure']) + self.assertEqual(g.name, 'grade') + self.assertEqual([ei.name for ei in g.items], ['A+', 'A', 'A-', 'Not Yet Set']) + def test_table_group(self) -> None: results = PyDBML.parse_file(TEST_DOCS_PATH / 'table_group.dbml') @@ -242,3 +252,15 @@ def test_table_group(self) -> None: self.assertEqual(tg1.items, [tb1, tb2, tb3]) self.assertEqual(tg2.name, 'e_commerce1') self.assertEqual(tg2.items, [merchants, countries]) + + def test_sticky_notes(self) -> None: + results = PyDBML.parse_file(TEST_DOCS_PATH / 'sticky_notes.dbml') + + self.assertEqual(len(results.sticky_notes), 2) + + sn1, sn2 = results.sticky_notes + + self.assertEqual(sn1.name, 'single_line_note') + self.assertEqual(sn1.text, 'This is a single line note') + self.assertEqual(sn2.name, 'multiple_lines_note') + self.assertEqual(sn2.text, '''This is a multiple lines note\nThis string can spans over multiple lines.''') diff --git a/test/test_doctest.py b/test/test_doctest.py new file mode 100644 index 0000000..9354d1e --- /dev/null +++ b/test/test_doctest.py @@ -0,0 +1,28 @@ +import doctest + +from pydbml import database +from pydbml._classes import column +from pydbml._classes import enum +from pydbml._classes import expression +from pydbml._classes import index +from pydbml._classes import note +from pydbml._classes import project +from pydbml._classes import reference +from pydbml._classes import table +from pydbml._classes import table_group +from pydbml.parser import parser + + +def load_tests(loader, tests, ignore): + tests.addTests(doctest.DocTestSuite(column)) + tests.addTests(doctest.DocTestSuite(enum)) + tests.addTests(doctest.DocTestSuite(expression)) + tests.addTests(doctest.DocTestSuite(index)) + tests.addTests(doctest.DocTestSuite(project)) + tests.addTests(doctest.DocTestSuite(note)) + tests.addTests(doctest.DocTestSuite(reference)) + tests.addTests(doctest.DocTestSuite(database)) + tests.addTests(doctest.DocTestSuite(table)) + tests.addTests(doctest.DocTestSuite(table_group)) + tests.addTests(doctest.DocTestSuite(parser)) + return tests diff --git a/test/test_editing.py b/test/test_editing.py new file mode 100644 index 0000000..05b9bc9 --- /dev/null +++ b/test/test_editing.py @@ -0,0 +1,90 @@ +import os + +from pathlib import Path +from unittest import TestCase + +from pyparsing import ParserElement + +from pydbml import PyDBML + + +ParserElement.set_default_whitespace_chars(' \t\r') + + +TEST_DATA_PATH = Path(os.path.abspath(__file__)).parent / 'test_data' + + +class EditingTestCase(TestCase): + def setUp(self): + self.dbml = PyDBML(TEST_DATA_PATH / 'editing.dbml') + + +class TestEditTable(EditingTestCase): + def test_name(self) -> None: + products = self.dbml['public.products'] + products.name = 'changed_products' + self.assertIn('CREATE TABLE "changed_products"', products.sql) + self.assertIn('CREATE INDEX "product_status" ON "changed_products"', products.sql) + self.assertIn('Table "changed_products"', products.dbml) + + ref = self.dbml.refs[0] + self.assertIn('ALTER TABLE "changed_products"', ref.sql) + self.assertIn('"changed_products"."merchant_id"', ref.dbml) + + index = products.indexes[0] + self.assertIn('ON "changed_products"', index.sql) + + def test_alias(self) -> None: + products = self.dbml['public.products'] + products.alias = 'new_alias' + + self.assertIn('as "new_alias"', products.dbml) + + +class TestColumn(EditingTestCase): + def test_name(self) -> None: + products = self.dbml['public.products'] + col = products['name'] + col.name = 'new_name' + self.assertEqual(col.sql, '"new_name" varchar') + self.assertEqual(col.dbml, '"new_name" varchar') + self.assertIn('"new_name" varchar', products.sql) + self.assertIn('"new_name" varchar', products.dbml) + + self.assertEqual(col, products[col.name]) + + def test_name_index(self) -> None: + products = self.dbml['public.products'] + col = products['status'] + col.name = 'changed_status' + self.assertIn('"changed_status"', products.indexes[0].sql) + self.assertIn('changed_status', products.indexes[0].dbml) + self.assertIn( + 'CREATE INDEX "product_status" ON "products" ("merchant_id", "changed_status");', + products.sql + ) + self.assertIn( + "(merchant_id, changed_status) [name: 'product_status']", + products.dbml + ) + + def test_name_ref(self) -> None: + products = self.dbml['public.products'] + col = products['merchant_id'] + col.name = 'changed_merchant_id' + merchants = self.dbml['public.merchants'] + table_ref = merchants.get_refs()[0] + self.assertIn('FOREIGN KEY ("changed_merchant_id")', table_ref.sql) + + +class TestEnum(EditingTestCase): + def test_enum_name(self): + products = self.dbml['public.products'] + enum = self.dbml.enums[0] + enum.name = 'changed product status' + self.assertIn('CREATE TYPE "changed product status"', enum.sql) + self.assertIn('Enum "changed product status"', enum.dbml) + + col = products['status'] + self.assertEqual(col.sql, '"status" "changed product status"') + self.assertEqual(col.dbml, '"status" "changed product status"') diff --git a/test/test_integration.py b/test/test_integration.py new file mode 100644 index 0000000..7b134bb --- /dev/null +++ b/test/test_integration.py @@ -0,0 +1,122 @@ +import os + +from pathlib import Path +from unittest import TestCase + +from pydbml import PyDBML +from pydbml.classes import Column +from pydbml.classes import Enum +from pydbml.classes import EnumItem +from pydbml.classes import Expression +from pydbml.classes import Index +from pydbml.classes import Note +from pydbml.classes import Project +from pydbml.classes import Reference +from pydbml.classes import Table +from pydbml.classes import TableGroup +from pydbml.database import Database + + +TEST_DATA_PATH = Path(os.path.abspath(__file__)).parent / 'test_data' + + +class TestGenerateDBML(TestCase): + def create_database(self) -> Database: + database = Database() + emp_level = Enum( + 'level', + [ + EnumItem('junior'), + EnumItem('middle'), + EnumItem('senior'), + ] + ) + database.add(emp_level) + + t1 = Table('Employees', alias='emp') + c11 = Column('id', 'integer', pk=True, autoinc=True) + c12 = Column('name', 'varchar', note=Note('Full employee name')) + c13 = Column('age', 'number', default=0) + c14 = Column('level', 'level') + c15 = Column('favorite_book_id', 'integer') + t1.add_column(c11) + t1.add_column(c12) + t1.add_column(c13) + t1.add_column(c14) + t1.add_column(c15) + database.add(t1) + + t2 = Table('books') + c21 = Column('id', 'integer', pk=True, autoinc=True) + c22 = Column('title', 'varchar') + c23 = Column('author', 'varchar') + c24 = Column('country_id', 'integer') + t2.add_column(c21) + t2.add_column(c22) + t2.add_column(c23) + t2.add_column(c24) + database.add(t2) + + t3 = Table('countries') + c31 = Column('id', 'integer', pk=True, autoinc=True) + c32 = Column('name', 'varchar2', unique=True) + t3.add_column(c31) + t3.add_column(c32) + i31 = Index([c32], unique=True) + t3.add_index(i31) + i32 = Index([Expression('UPPER(name)')]) + t3.add_index(i32) + database.add(t3) + + ref1 = Reference('>', c15, c21) + database.add(ref1) + + ref2 = Reference('<', c31, c24, name='Country Reference', inline=True) + database.add(ref2) + + tg = TableGroup('Unanimate', [t2, t3]) + database.add(tg) + + p = Project('my project', {'author': 'me', 'reason': 'testing'}) + database.add(p) + return database + + def test_generate_dbml(self) -> None: + database = self.create_database() + with open(TEST_DATA_PATH / 'integration1.dbml') as f: + expected = f.read() + self.assertEqual(database.dbml, expected) + + def test_generate_sql(self) -> None: + database = self.create_database() + with open(TEST_DATA_PATH / 'integration1.sql') as f: + expected = f.read() + self.assertEqual(database.sql, expected) + + def test_parser(self): + source_path = TEST_DATA_PATH / 'integration1.dbml' + with self.assertRaises(TypeError): + PyDBML(2) + res1 = PyDBML(source_path) + self.assertIsInstance(res1, Database) + with open(source_path) as f: + res2 = PyDBML(f) + self.assertIsInstance(res2, Database) + with open(source_path) as f: + source = f.read() + res3 = PyDBML(source) + self.assertIsInstance(res3, Database) + res4 = PyDBML('\ufeff' + source) + self.assertIsInstance(res4, Database) + + pydbml = PyDBML() + self.assertIsInstance(pydbml, PyDBML) + res5 = pydbml.parse(source) + self.assertIsInstance(res5, Database) + res6 = PyDBML.parse('\ufeff' + source) + self.assertIsInstance(res6, Database) + res7 = PyDBML.parse_file(str(source_path)) + self.assertIsInstance(res7, Database) + with open(source_path) as f: + res8 = PyDBML.parse_file(f) + self.assertIsInstance(res8, Database) diff --git a/test/test_parser.py b/test/test_parser.py index f6b76f7..745188e 100644 --- a/test/test_parser.py +++ b/test/test_parser.py @@ -6,6 +6,7 @@ from pydbml import PyDBML from pydbml.exceptions import ColumnNotFoundError from pydbml.exceptions import TableNotFoundError +from pydbml.parser.parser import PyDBMLParser TEST_DATA_PATH = Path(os.path.abspath(__file__)).parent / 'test_data' @@ -17,73 +18,102 @@ def setUp(self): def test_table_refs(self) -> None: p = PyDBML.parse_file(TEST_DATA_PATH / 'general.dbml') - r = p['order_items'].refs - self.assertEqual(r[0].col[0].name, 'order_id') - self.assertEqual(r[0].ref_table.name, 'orders') - self.assertEqual(r[0].ref_col[0].name, 'id') - r = p['products'].refs - self.assertEqual(r[0].col[0].name, 'merchant_id') - self.assertEqual(r[0].ref_table.name, 'merchants') - self.assertEqual(r[0].ref_col[0].name, 'id') - r = p['users'].refs - self.assertEqual(r[0].col[0].name, 'country_code') - self.assertEqual(r[0].ref_table.name, 'countries') - self.assertEqual(r[0].ref_col[0].name, 'code') + r = p['public.orders'].get_refs() + self.assertEqual(r[0].col2[0].name, 'order_id') + self.assertEqual(r[0].col1[0].table.name, 'orders') + self.assertEqual(r[0].col1[0].name, 'id') + r = p['public.products'].get_refs() + self.assertEqual(r[1].col1[0].name, 'merchant_id') + self.assertEqual(r[1].col2[0].table.name, 'merchants') + self.assertEqual(r[1].col2[0].name, 'id') + r = p['public.users'].get_refs() + self.assertEqual(r[0].col1[0].name, 'country_code') + self.assertEqual(r[0].col2[0].table.name, 'countries') + self.assertEqual(r[0].col2[0].name, 'code') def test_refs(self) -> None: p = PyDBML.parse_file(TEST_DATA_PATH / 'general.dbml') r = p.refs - self.assertEqual(r[0].table1.name, 'orders') + self.assertEqual(r[0].col1[0].table.name, 'orders') self.assertEqual(r[0].col1[0].name, 'id') - self.assertEqual(r[0].table2.name, 'order_items') + self.assertEqual(r[0].col2[0].table.name, 'order_items') self.assertEqual(r[0].col2[0].name, 'order_id') - self.assertEqual(r[2].table1.name, 'users') + self.assertEqual(r[2].col1[0].table.name, 'users') self.assertEqual(r[2].col1[0].name, 'country_code') - self.assertEqual(r[2].table2.name, 'countries') + self.assertEqual(r[2].col2[0].table.name, 'countries') self.assertEqual(r[2].col2[0].name, 'code') - self.assertEqual(r[4].table1.name, 'products') + self.assertEqual(r[4].col1[0].table.name, 'products') self.assertEqual(r[4].col1[0].name, 'merchant_id') - self.assertEqual(r[4].table2.name, 'merchants') + self.assertEqual(r[4].col2[0].table.name, 'merchants') self.assertEqual(r[4].col2[0].name, 'id') + def test_inline_refs_schema(self) -> None: + # Thanks @jens-koster for this example + source = ''' +Table core.pk_tbl { + pk_col varchar [pk] +} +Table core.fk_tbl { + fk_col varchar [ref: > core.pk_tbl.pk_col] +} +''' + p = PyDBMLParser(source) + p.parse() + r = p.refs + pk_tbl = p.tables[0] + fk_tble = p.tables[1] + ref = p.refs[0] + self.assertEqual(ref.table1, fk_tble.name) + self.assertEqual(ref.table2, pk_tbl.name) + self.assertEqual(ref.schema1, fk_tble.schema) + self.assertEqual(ref.schema2, pk_tbl.schema) + class TestRefs(TestCase): def test_reference_aliases(self): results = PyDBML.parse_file(TEST_DATA_PATH / 'relationships_aliases.dbml') - posts, reviews, users = results['posts'], results['reviews'], results['users'] - posts2, reviews2, users2 = results['posts2'], results['reviews2'], results['users2'] + posts, reviews, users = results['public.posts'], results['public.reviews'], results['public.users'] + posts2, reviews2, users2 = results['public.posts2'], results['public.reviews2'], results['public.users2'] rs = results.refs - self.assertEqual(rs[0].table1, users) - self.assertEqual(rs[0].table2, posts) - self.assertEqual(rs[1].table1, users) - self.assertEqual(rs[1].table2, reviews) + self.assertEqual(rs[0].col1[0].table, users) + self.assertEqual(rs[0].col2[0].table, posts) + self.assertEqual(rs[1].col1[0].table, users) + self.assertEqual(rs[1].col2[0].table, reviews) - self.assertEqual(rs[2].table1, posts2) - self.assertEqual(rs[2].table2, users2) - self.assertEqual(rs[3].table1, reviews2) - self.assertEqual(rs[3].table2, users2) + self.assertEqual(rs[2].col1[0].table, posts2) + self.assertEqual(rs[2].col2[0].table, users2) + self.assertEqual(rs[3].col1[0].table, reviews2) + self.assertEqual(rs[3].col2[0].table, users2) def test_composite_references(self): results = PyDBML.parse_file(TEST_DATA_PATH / 'relationships_composite.dbml') self.assertEqual(len(results.tables), 4) - posts, reviews = results['posts'], results['reviews'] - posts2, reviews2 = results['posts2'], results['reviews2'] + posts, reviews = results['public.posts'], results['public.reviews'] + posts2, reviews2 = results['public.posts2'], results['public.reviews2'] rs = results.refs self.assertEqual(len(rs), 2) - self.assertEqual(rs[0].table1, posts) + self.assertEqual(rs[0].col1[0].table, posts) self.assertEqual(rs[0].col1, [posts['id'], posts['tag']]) - self.assertEqual(rs[0].table2, reviews) + self.assertEqual(rs[0].col2[0].table, reviews) self.assertEqual(rs[0].col2, [reviews['post_id'], reviews['tag']]) - self.assertEqual(rs[1].table1, posts2) + self.assertEqual(rs[1].col1[0].table, posts2) self.assertEqual(rs[1].col1, [posts2['id'], posts2['tag']]) - self.assertEqual(rs[1].table2, reviews2) + self.assertEqual(rs[1].col2[0].table, reviews2) self.assertEqual(rs[1].col2, [reviews2['post_id'], reviews2['tag']]) +class TestDBMLReferenceDef(TestCase): + def test_dbml_reference_def(self): + results = PyDBML.parse_file(TEST_DATA_PATH / 'dbml_schema_def.dbml') + self.assertEqual(len(results.tables), 5) + self.assertEqual(len(results.table_groups), 1) + self.assertEqual(len(results.enums), 3) + + class TestFaulty(TestCase): def test_bad_reference(self) -> None: with self.assertRaises(TableNotFoundError): @@ -94,3 +124,64 @@ def test_bad_reference(self) -> None: def test_bad_index(self) -> None: with self.assertRaises(ColumnNotFoundError): PyDBML(TEST_DATA_PATH / 'wrong_index.dbml') + + +class TestPyDBMLParser(TestCase): + def test_edge(self) -> None: + p = PyDBMLParser('') + with self.assertRaises(RuntimeError): + p.locate_table('myschema', 'test') + with self.assertRaises(RuntimeError): + p.parse_blueprint(1, 1, [1]) + + +class TestNotesIdempotent(TestCase): + def test_note_is_idempotent(self): + dbml_source = """ +Table test { + id integer + Note { + ''' + Indented note which is actually a Markdown formatted string: + + - List item 1 + - Another list item + + ```python + def test(): + print('Hello world!') + return 1 + ``` + ''' + } +} +""" + source_text = \ +"""Indented note which is actually a Markdown formatted string: + +- List item 1 +- Another list item + +```python +def test(): + print('Hello world!') + return 1 +```""" + p = PyDBML(dbml_source) + note = p.tables[0].note + self.assertEqual(source_text, note.text) + + p_mod = p + for _ in range(10): + p_mod = PyDBML(p_mod.dbml) + note2 = p_mod.tables[0].note + self.assertEqual(source_text, note2.text) + + +def test_repr_pydbml() -> None: + assert repr(PyDBML()) == "" + + +def test_repr_pydbml_parser() -> None: + assert repr(PyDBMLParser('')) == "" + diff --git a/test/test_renderer/__init__.py b/test/test_renderer/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/test_renderer/test_base.py b/test/test_renderer/test_base.py new file mode 100644 index 0000000..9e9f80c --- /dev/null +++ b/test/test_renderer/test_base.py @@ -0,0 +1,40 @@ +from pydbml.renderer.base import BaseRenderer + + +class SampleRenderer(BaseRenderer): + model_renderers = {} + + +def test_renderer_for() -> None: + @SampleRenderer.renderer_for(str) + def render_str(model): + return 'str' + + assert len(SampleRenderer.model_renderers) == 1 + assert str in SampleRenderer.model_renderers + assert SampleRenderer.model_renderers[str] is render_str + + +class TestRender: + @staticmethod + def test_render() -> None: + @SampleRenderer.renderer_for(str) + def render_str(model): + return 'str' + + assert SampleRenderer.render('') == 'str' + + @staticmethod + def test_render_not_supported() -> None: + assert SampleRenderer.render(1) == '' + + @staticmethod + def test_unsupported_renderer_override() -> None: + def unsupported_renderer(model): + return 'unsupported' + + class SampleRenderer2(BaseRenderer): + model_renderers = {} + _unsupported_renderer = unsupported_renderer + + assert SampleRenderer2.render(1) == 'unsupported' diff --git a/test/test_renderer/test_dbml/__init__.py b/test/test_renderer/test_dbml/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/test_renderer/test_dbml/test_column.py b/test/test_renderer/test_dbml/test_column.py new file mode 100644 index 0000000..b2cf3b8 --- /dev/null +++ b/test/test_renderer/test_dbml/test_column.py @@ -0,0 +1,139 @@ +from enum import Enum +from typing import Any +from unittest.mock import Mock, patch + +import pytest + +from pydbml.classes import Column, Note +from pydbml.renderer.dbml.default.column import ( + default_to_str, + render_options, + render_column, +) + + +@pytest.mark.parametrize( + "input,expected", + [ + ("test", "'test'"), + (1, "1"), + (1.0, "1.0"), + (True, "True"), + ("False", "false"), + ("null", "null"), + ("b'0'", "'b\\'0\\''"), + ], +) +def test_default_to_str(input: Any, expected: str) -> None: + assert default_to_str(input) == expected + + +class TestRenderOptions: + @staticmethod + def test_refs(simple_column: Column) -> None: + simple_column.get_refs = Mock( + return_value=[ + Mock(dbml="ref1", inline=True), + Mock(dbml="ref2", inline=False), + Mock(dbml="ref3", inline=True), + ] + ) + assert render_options(simple_column) == " [ref1, ref3]" + + @staticmethod + def test_pk(simple_column_with_table: Column) -> None: + simple_column_with_table.pk = True + assert render_options(simple_column_with_table) == " [pk]" + + @staticmethod + def test_autoinc(simple_column_with_table: Column) -> None: + simple_column_with_table.autoinc = True + assert render_options(simple_column_with_table) == " [increment]" + + @staticmethod + def test_default(simple_column_with_table: Column) -> None: + simple_column_with_table.default = "6" + assert render_options(simple_column_with_table) == " [default: '6']" + + @staticmethod + def test_unique(simple_column_with_table: Column) -> None: + simple_column_with_table.unique = True + assert render_options(simple_column_with_table) == " [unique]" + + @staticmethod + def test_not_null(simple_column_with_table: Column) -> None: + simple_column_with_table.not_null = True + assert render_options(simple_column_with_table) == " [not null]" + + @staticmethod + def test_note(simple_column_with_table: Column) -> None: + simple_column_with_table.note = Note("note") + with patch( + "pydbml.renderer.dbml.default.column.note_option_to_dbml", + Mock(return_value="note"), + ): + assert render_options(simple_column_with_table) == " [note]" + + @staticmethod + def test_no_options(simple_column_with_table: Column) -> None: + assert render_options(simple_column_with_table) == "" + + @staticmethod + def test_properties(simple_column_with_table: Column) -> None: + simple_column_with_table.properties = {"key": "value"} + simple_column_with_table.table.database.allow_properties = True + assert render_options(simple_column_with_table) == " [key: 'value']" + + @staticmethod + def test_properties_not_allowed(simple_column_with_table: Column) -> None: + simple_column_with_table.properties = {"key": "value"} + simple_column_with_table.table.database.allow_properties = False + assert render_options(simple_column_with_table) == "" + + @staticmethod + def test_all_options(complex_column: Column) -> None: + complex_column.table = Mock(database=Mock(allow_properties=True)) + complex_column.get_refs = Mock( + return_value=[ + Mock(dbml="ref1", inline=True), + Mock(dbml="ref2", inline=False), + Mock(dbml="ref3", inline=True), + ] + ) + complex_column.default = "null" + + expected = ( + " [ref1, ref3, pk, increment, default: null, unique, not null, note, foo: " + "'bar', baz: '''\n" + "qux\n" + "qux''']" + ) + + with patch( + "pydbml.renderer.dbml.default.column.note_option_to_dbml", + Mock(return_value="note"), + ): + assert render_options(complex_column) == expected + + +class TestRenderColumn: + @staticmethod + def test_comment(simple_column_with_table: Column) -> None: + simple_column_with_table.comment = "Simple comment" + expected = '// Simple comment\n"id" integer' + assert render_column(simple_column_with_table) == expected + + @staticmethod + def test_enum(simple_column_with_table: Column, enum1: Enum) -> None: + simple_column_with_table.type = enum1 + expected = '"id" "product status"' + assert render_column(simple_column_with_table) == expected + + @staticmethod + def test_complex(complex_column_with_table: Column) -> None: + expected = ( + "// This is a counter column\n" + '"counter" "product status" [pk, increment, unique, not null, note: \'This is ' + "a note for the column']" + ) + assert render_column(complex_column_with_table) == expected diff --git a/test/test_renderer/test_dbml/test_enum.py b/test/test_renderer/test_dbml/test_enum.py new file mode 100644 index 0000000..64d2daf --- /dev/null +++ b/test/test_renderer/test_dbml/test_enum.py @@ -0,0 +1,36 @@ +from pydbml.classes import Enum, EnumItem, Note +from pydbml.renderer.dbml.default.enum import ( + render_enum_item, + render_enum, +) + + +class TestRenderEnumItem: + @staticmethod + def test_simple(enum_item1: EnumItem) -> None: + assert render_enum_item(enum_item1) == '"en-US"' + + @staticmethod + def test_comment(enum_item1: EnumItem) -> None: + enum_item1.comment = "comment" + expected = '// comment\n"en-US"' + assert render_enum_item(enum_item1) == expected + + @staticmethod + def test_note(enum_item1: EnumItem) -> None: + enum_item1.note = Note("Enum item note") + expected = "\"en-US\" [note: 'Enum item note']" + assert render_enum_item(enum_item1) == expected + + +class TestEnum: + @staticmethod + def test_simple(enum1: Enum) -> None: + expected = 'Enum "product status" {\n "production"\n "development"\n}' + assert render_enum(enum1) == expected + + @staticmethod + def test_comment(enum1: Enum) -> None: + enum1.comment = "comment" + expected = '// comment\nEnum "product status" {\n "production"\n "development"\n}' + assert render_enum(enum1) == expected diff --git a/test/test_renderer/test_dbml/test_expression.py b/test/test_renderer/test_dbml/test_expression.py new file mode 100644 index 0000000..dc189b0 --- /dev/null +++ b/test/test_renderer/test_dbml/test_expression.py @@ -0,0 +1,6 @@ +from pydbml.classes import Expression +from pydbml.renderer.dbml.default import render_expression + + +def test_render_expression(expression1: Expression) -> None: + assert render_expression(expression1) == "`SUM(amount)`" diff --git a/test/test_renderer/test_dbml/test_index.py b/test/test_renderer/test_dbml/test_index.py new file mode 100644 index 0000000..06ff29b --- /dev/null +++ b/test/test_renderer/test_dbml/test_index.py @@ -0,0 +1,92 @@ +from unittest.mock import patch, Mock + +from pydbml.classes import Index, Expression, Note +from pydbml.renderer.dbml.default.index import render_subjects, render_options, render_index + + +class TestRenderSubjects: + @staticmethod + def test_column(index1: Index) -> None: + assert render_subjects(index1.subjects) == "name" + + @staticmethod + def test_expression(index1: Index) -> None: + index1.subjects = [Expression("SUM(amount)")] + assert render_subjects(index1.subjects) == "`SUM(amount)`" + + @staticmethod + def test_string(index1: Index) -> None: + index1.subjects = ["name"] + assert render_subjects(index1.subjects) == "name" + + @staticmethod + def test_multiple(index1: Index) -> None: + index1.subjects.append(Expression("SUM(amount)")) + index1.subjects.append("name") + assert render_subjects(index1.subjects) == "(name, `SUM(amount)`, name)" + + +class TestRenderOptions: + @staticmethod + def test_name(index1: Index) -> None: + index1.name = "index_name" + assert render_options(index1) == " [name: 'index_name']" + + @staticmethod + def test_pk(index1: Index) -> None: + index1.pk = True + assert render_options(index1) == " [pk]" + + @staticmethod + def test_unique(index1: Index) -> None: + index1.unique = True + assert render_options(index1) == " [unique]" + + @staticmethod + def test_type(index1: Index) -> None: + index1.type = "hash" + assert render_options(index1) == " [type: hash]" + + @staticmethod + def test_note(index1: Index) -> None: + index1.note = Note("note") + with patch( + "pydbml.renderer.dbml.default.index.note_option_to_dbml", + Mock(return_value="note"), + ): + assert render_options(index1) == " [note]" + + @staticmethod + def test_no_options(index1: Index) -> None: + assert render_options(index1) == "" + + @staticmethod + def test_all_options(index1: Index) -> None: + index1.name = "index_name" + index1.pk = True + index1.unique = True + index1.type = "hash" + index1.note = Note("note") + with patch( + "pydbml.renderer.dbml.default.index.note_option_to_dbml", + Mock(return_value="note"), + ): + assert ( + render_options(index1) + == " [name: 'index_name', pk, unique, type: hash, note]" + ) + + +def test_render_index(index1: Index) -> None: + index1.comment = "Index comment" + with patch( + "pydbml.renderer.dbml.default.index.render_subjects", + Mock(return_value="subjects "), + ) as render_subjects_mock: + with patch( + "pydbml.renderer.dbml.default.index.render_options", + Mock(return_value="options"), + ) as render_options_mock: + assert render_index(index1) == '// Index comment\nsubjects options' + assert render_subjects_mock.called + assert render_options_mock.called diff --git a/test/test_renderer/test_dbml/test_note.py b/test/test_renderer/test_dbml/test_note.py new file mode 100644 index 0000000..20df62d --- /dev/null +++ b/test/test_renderer/test_dbml/test_note.py @@ -0,0 +1,23 @@ +from pydbml.classes import Note +from pydbml.renderer.dbml.default.note import render_note +from pydbml.renderer.dbml.default.utils import prepare_text_for_dbml + + +def test_prepare_text_for_dbml() -> None: + note = Note("""Three quotes: ''', one quote: '.""") + assert prepare_text_for_dbml(note.text) == "Three quotes: \\''', one quote: \\'." + + +class TestRenderNote: + @staticmethod + def test_oneline() -> None: + note = Note("Note text") + assert render_note(note) == "Note {\n 'Note text'\n}" + + @staticmethod + def test_multiline() -> None: + note = Note("Note text\nwith multiple lines") + assert ( + render_note(note) + == "Note {\n '''\n Note text\n with multiple lines'''\n}" + ) diff --git a/test/test_renderer/test_dbml/test_project.py b/test/test_renderer/test_dbml/test_project.py new file mode 100644 index 0000000..c2d2b76 --- /dev/null +++ b/test/test_renderer/test_dbml/test_project.py @@ -0,0 +1,46 @@ +from pydbml.classes import Project, Note +from pydbml.renderer.dbml.default.project import render_items, render_project + + +class TestRenderItems: + @staticmethod + def test_oneline(): + project = Project(name="test", items={"key1": "value1"}) + assert render_items(project.items) == " key1: 'value1'\n" + + @staticmethod + def test_multiline(): + project = Project(name="test", items={"key1": "value1\nvalue2"}) + assert render_items(project.items) == " key1: '''value1\n value2'''\n" + + @staticmethod + def test_multiple(): + project = Project( + name="test", items={"key1": "value1", "key2": "value2\nnewline"} + ) + assert ( + render_items(project.items) + == " key1: 'value1'\n key2: '''value2\n newline'''\n" + ) + + +class TestRenderProject: + @staticmethod + def test_no_note() -> None: + project = Project(name="test", items={"key1": "value1"}) + expected = "Project \"test\" {\n key1: 'value1'\n}" + assert render_project(project) == expected + + @staticmethod + def test_note() -> None: + project = Project(name="test", items={"key1": "value1"}) + project.note = Note("Note text") + expected = ( + 'Project "test" {\n' + " key1: 'value1'\n" + " Note {\n" + " 'Note text'\n" + " }\n" + "}" + ) + assert render_project(project) == expected diff --git a/test/test_renderer/test_dbml/test_reference.py b/test/test_renderer/test_dbml/test_reference.py new file mode 100644 index 0000000..85fdac1 --- /dev/null +++ b/test/test_renderer/test_dbml/test_reference.py @@ -0,0 +1,123 @@ +from unittest.mock import patch + +import pytest + +from pydbml._classes.reference import Reference +from pydbml.exceptions import TableNotFoundError, DBMLError +from pydbml.renderer.dbml.default.reference import ( + validate_for_dbml, + render_inline_reference, + render_col, + render_options, + render_not_inline_reference, + render_reference, +) + + +class TestValidateFroDBML: + @staticmethod + def test_ok(reference1: Reference) -> None: + validate_for_dbml(reference1) + + @staticmethod + def test_no_table(reference1: Reference) -> None: + reference1.col2[0].table = None + with pytest.raises(TableNotFoundError): + validate_for_dbml(reference1) + + +class TestRenderInlineReference: + @staticmethod + def test_ok(reference1: Reference) -> None: + reference1.inline = True + assert render_inline_reference(reference1) == 'ref: > "products"."id"' + + @staticmethod + def test_composite(reference1: Reference) -> None: + reference1.col2.append(reference1.col2[0]) + with pytest.raises(DBMLError): + render_inline_reference(reference1) + + +class TestRendeCol: + @staticmethod + def test_single(reference1: Reference) -> None: + assert render_col(reference1.col2) == '"id"' + + @staticmethod + def test_multiple(reference1: Reference) -> None: + reference1.col2.append(reference1.col2[0]) + assert render_col(reference1.col2) == '("id", "id")' + + +class TestRenderOptions: + @staticmethod + def test_on_update(reference1: Reference) -> None: + reference1.on_update = "cascade" + assert render_options(reference1) == " [update: cascade]" + + @staticmethod + def test_on_delete(reference1: Reference) -> None: + reference1.on_delete = "set null" + assert render_options(reference1) == " [delete: set null]" + + @staticmethod + def test_both(reference1: Reference) -> None: + reference1.on_update = "cascade" + reference1.on_delete = "set null" + assert render_options(reference1) == " [update: cascade, delete: set null]" + + @staticmethod + def test_no_options(reference1: Reference) -> None: + assert render_options(reference1) == "" + + +class TestRenderNotInlineReference: + @staticmethod + def test_ok(reference1: Reference) -> None: + assert render_not_inline_reference(reference1) == ( + 'Ref {\n "orders"."product_id" > "products"."id"\n}' + ) + + @staticmethod + def test_comment(reference1: Reference) -> None: + reference1.comment = "comment" + assert render_not_inline_reference(reference1) == ( + '// comment\nRef {\n "orders"."product_id" > "products"."id"\n}' + ) + + @staticmethod + def test_name(reference1: Reference) -> None: + reference1.name = "ref_name" + assert render_not_inline_reference(reference1) == ( + 'Ref ref_name {\n "orders"."product_id" > "products"."id"\n}' + ) + + +class TestRenderReference: + @staticmethod + def test_inline(reference1: Reference) -> None: + reference1.inline = True + with patch( + "pydbml.renderer.dbml.default.reference.render_inline_reference", + return_value="inline", + ) as mock_render: + with patch( + "pydbml.renderer.dbml.default.reference.validate_for_dbml", + ) as mock_validate: + assert render_reference(reference1) == "inline" + assert mock_render.called + assert mock_validate.called + + @staticmethod + def test_not_inline(reference1: Reference) -> None: + with patch( + "pydbml.renderer.dbml.default.reference.render_not_inline_reference", + return_value="not inline", + ) as mock_render: + with patch( + "pydbml.renderer.dbml.default.reference.validate_for_dbml", + ) as mock_validate: + assert render_reference(reference1) == "not inline" + assert mock_render.called + assert mock_validate.called diff --git a/test/test_renderer/test_dbml/test_renderer.py b/test/test_renderer/test_dbml/test_renderer.py new file mode 100644 index 0000000..2e8acf6 --- /dev/null +++ b/test/test_renderer/test_dbml/test_renderer.py @@ -0,0 +1,20 @@ +from unittest.mock import Mock, patch + +from pydbml.renderer.dbml.default import DefaultDBMLRenderer + + +def test_render_db() -> None: + db = Mock( + project=Mock(), # #1 + refs=(Mock(inline=False), Mock(inline=False), Mock(inline=True)), # #2, #3 + tables=[Mock(), Mock(), Mock()], # #4, #5, #6 + enums=[Mock(), Mock()], # #7, #8 + table_groups=[Mock(), Mock()], # #9, #10 + sticky_notes=[Mock(), Mock()], # #11, #12 + ) + + with patch.object( + DefaultDBMLRenderer, "render", Mock(return_value="") + ) as render_mock: + DefaultDBMLRenderer.render_db(db) + assert render_mock.call_count == 12 diff --git a/test/test_renderer/test_dbml/test_sticky_note.py b/test/test_renderer/test_dbml/test_sticky_note.py new file mode 100644 index 0000000..1f5bc2a --- /dev/null +++ b/test/test_renderer/test_dbml/test_sticky_note.py @@ -0,0 +1,17 @@ +from pydbml._classes.sticky_note import StickyNote +from pydbml.renderer.dbml.default import render_sticky_note + + +class TestRenderNote: + @staticmethod + def test_oneline() -> None: + note = StickyNote(name='mynote', text="Note text") + assert render_sticky_note(note) == "Note mynote {\n 'Note text'\n}" + + @staticmethod + def test_multiline() -> None: + note = StickyNote(name='mynote', text="Note text\nwith multiple lines") + assert ( + render_sticky_note(note) + == "Note mynote {\n '''\n Note text\n with multiple lines'''\n}" + ) diff --git a/test/test_renderer/test_dbml/test_table.py b/test/test_renderer/test_dbml/test_table.py new file mode 100644 index 0000000..7843a5d --- /dev/null +++ b/test/test_renderer/test_dbml/test_table.py @@ -0,0 +1,109 @@ +from pydbml import Database +from pydbml.classes import Table, Index, Note +from pydbml.renderer.dbml.default.table import ( + get_full_name_for_dbml, + render_header, + render_indexes, + render_table, +) + + +class TestGetFullNameForDBML: + @staticmethod + def test_no_schema(table1: Table) -> None: + table1.schema = "public" + assert get_full_name_for_dbml(table1) == '"products"' + + @staticmethod + def test_with_schema(table1: Table) -> None: + table1.schema = "myschema" + assert get_full_name_for_dbml(table1) == '"myschema"."products"' + + +class TestRenderHeader: + @staticmethod + def test_simple(table1: Table) -> None: + expected = 'Table "products" ' + assert render_header(table1) == expected + + @staticmethod + def test_alias(table1: Table) -> None: + table1.alias = "p" + expected = 'Table "products" as "p" ' + assert render_header(table1) == expected + + @staticmethod + def test_header_color(table1: Table) -> None: + table1.header_color = "red" + expected = 'Table "products" [headercolor: red] ' + assert render_header(table1) == expected + + @staticmethod + def test_all(table1: Table) -> None: + table1.alias = "p" + table1.header_color = "red" + expected = 'Table "products" as "p" [headercolor: red] ' + assert render_header(table1) == expected + + +class TestRenderIndexes: + @staticmethod + def test_no_indexes(table1: Table) -> None: + assert render_indexes(table1) == "" + + @staticmethod + def test_one_index(index1: Index) -> None: + assert render_indexes(index1.table) == "\n indexes {\n name\n }\n" + + +class TestRenderTable: + @staticmethod + def test_simple(db: Database, table1: Table) -> None: + db.add(table1) + expected = 'Table "products" {\n "id" integer\n "name" varchar\n}' + assert render_table(table1) == expected + + @staticmethod + def test_note_and_comment(db: Database, table1: Table) -> None: + table1.comment = "Table comment" + table1.note = Note("Table note") + db.add(table1) + expected = ( + "// Table comment\n" + 'Table "products" {\n' + ' "id" integer\n' + ' "name" varchar\n' + " Note {\n" + " 'Table note'\n" + " }\n" + "}" + ) + assert render_table(table1) == expected + + @staticmethod + def test_properties(db: Database, table1: Table) -> None: + table1.properties = {"key": "value"} + db.add(table1) + db.allow_properties = True + expected = ( + 'Table "products" {\n' + ' "id" integer\n' + ' "name" varchar\n' + "\n" + " key: 'value'\n" + "}" + ) + assert render_table(table1) == expected + + @staticmethod + def test_properties_not_allowed(db: Database, table1: Table) -> None: + table1.properties = {"key": "value"} + db.add(table1) + db.allow_properties = False + expected = ( + 'Table "products" {\n' + ' "id" integer\n' + ' "name" varchar\n' + "}" + ) + assert render_table(table1) == expected diff --git a/test/test_renderer/test_dbml/test_table_group.py b/test/test_renderer/test_dbml/test_table_group.py new file mode 100644 index 0000000..5342b39 --- /dev/null +++ b/test/test_renderer/test_dbml/test_table_group.py @@ -0,0 +1,45 @@ +from pydbml._classes.note import Note +from pydbml._classes.table_group import TableGroup +from pydbml.classes import Table +from pydbml.renderer.dbml.default import render_table_group + + +class TestTableGroup: + @staticmethod + def test_simple(table1: Table, table2: Table, table3: Table) -> None: + tg = TableGroup( + name="mygroup", + items=[table1, table2, table3], + ) + expected = ( + 'TableGroup "mygroup" {\n' + ' "products"\n' + ' "products"\n' + ' "orders"\n' + '}' + ) + assert render_table_group(tg) == expected + + @staticmethod + def test_full(table1: Table, table2: Table, table3: Table) -> None: + tg = TableGroup( + name="mygroup", + items=[table1, table2, table3], + comment="My comment", + note=Note('Note line1\nNote line2'), + color='#FFF' + ) + expected = ( + '// My comment\n' + 'TableGroup "mygroup" [color: #FFF] {\n' + ' "products"\n' + ' "products"\n' + ' "orders"\n' + ' Note {\n' + " '''\n" + ' Note line1\n' + " Note line2'''\n" + ' }\n' + '}' + ) + assert render_table_group(tg) == expected diff --git a/test/test_renderer/test_dbml/test_utils.py b/test/test_renderer/test_dbml/test_utils.py new file mode 100644 index 0000000..11a2151 --- /dev/null +++ b/test/test_renderer/test_dbml/test_utils.py @@ -0,0 +1,20 @@ +from pydbml.classes import Note +from pydbml.renderer.dbml.default.utils import note_option_to_dbml, comment_to_dbml + + +class TestNoteOptionsToDBML: + @staticmethod + def test_oneline() -> None: + note = Note("One line note") + expected = "note: 'One line note'" + assert note_option_to_dbml(note) == expected + + @staticmethod + def test_multiline() -> None: + note = Note("Multiline\nnote") + expected = "note: '''Multiline\nnote'''" + assert note_option_to_dbml(note) == expected + + +def test_comment_to_dbml() -> None: + assert comment_to_dbml("Simple comment") == "// Simple comment\n" diff --git a/test/test_renderer/test_sql/__init__.py b/test/test_renderer/test_sql/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/test_renderer/test_sql/test_default/__init__.py b/test/test_renderer/test_sql/test_default/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/test_renderer/test_sql/test_default/test_column.py b/test/test_renderer/test_sql/test_default/test_column.py new file mode 100644 index 0000000..05f9c20 --- /dev/null +++ b/test/test_renderer/test_sql/test_default/test_column.py @@ -0,0 +1,19 @@ +from pydbml.classes import Column +from pydbml.renderer.sql.default import render_column + + +class TestRenderColumn: + @staticmethod + def test_simple(simple_column: Column) -> None: + expected = '"id" integer' + + assert render_column(simple_column), expected + + @staticmethod + def test_complex(complex_column: Column) -> None: + expected = ( + "-- This is a counter column\n" + '"counter" "product status" PRIMARY KEY AUTOINCREMENT UNIQUE NOT NULL DEFAULT ' + "0" + ) + assert render_column(complex_column) == expected diff --git a/test/test_renderer/test_sql/test_default/test_enum.py b/test/test_renderer/test_sql/test_default/test_enum.py new file mode 100644 index 0000000..9a5d4ee --- /dev/null +++ b/test/test_renderer/test_sql/test_default/test_enum.py @@ -0,0 +1,39 @@ +from pydbml.classes import EnumItem, Enum +from pydbml.renderer.sql.default import render_enum, render_enum_item + + +class TestRenderEnumItem: + @staticmethod + def test_simple(enum_item1: EnumItem): + expected = "'en-US'," + assert render_enum_item(enum_item1) == expected + + @staticmethod + def test_comment(enum_item1: EnumItem): + enum_item1.comment = "Test comment" + expected = "-- Test comment\n'en-US'," + assert render_enum_item(enum_item1) == expected + + +class TestRenderEnum: + @staticmethod + def test_simple_enum(enum1: Enum) -> None: + expected = ( + 'CREATE TYPE "product status" AS ENUM (\n' + " 'production',\n" + " 'development'\n" + ");" + ) + assert render_enum(enum1) == expected + + @staticmethod + def test_comments(enum1: Enum) -> None: + enum1.comment = "Enum comment" + expected = ( + "-- Enum comment\n" + 'CREATE TYPE "product status" AS ENUM (\n' + " 'production',\n" + " 'development'\n" + ");" + ) + assert render_enum(enum1) == expected diff --git a/test/test_renderer/test_sql/test_default/test_expression.py b/test/test_renderer/test_sql/test_default/test_expression.py new file mode 100644 index 0000000..2d10c1f --- /dev/null +++ b/test/test_renderer/test_sql/test_default/test_expression.py @@ -0,0 +1,6 @@ +from pydbml.classes import Expression +from pydbml.renderer.sql.default import render_expression + + +def test_render_expression(expression1: Expression): + assert render_expression(expression1) == '(SUM(amount))' diff --git a/test/test_renderer/test_sql/test_default/test_index.py b/test/test_renderer/test_sql/test_default/test_index.py new file mode 100644 index 0000000..83859e7 --- /dev/null +++ b/test/test_renderer/test_sql/test_default/test_index.py @@ -0,0 +1,82 @@ +from pydbml.classes import Column, Expression, Index +from pydbml.renderer.sql.default.index import render_subject, render_index, render_pk + + +class TestRenderSubject: + @staticmethod + def test_column(simple_column: Column) -> None: + expected = '"id"' + assert render_subject(simple_column) == expected + + @staticmethod + def test_expression(expression1: Expression) -> None: + expected = "(SUM(amount))" + assert render_subject(expression1) == expected + + @staticmethod + def test_other() -> None: + expected = "test" + assert render_subject(expected) == expected + + +class TestRenderPK: + @staticmethod + def test_comment(index1: Index) -> None: + index1.comment = "Test comment" + expected = '-- Test comment\nPRIMARY KEY ("name")' + assert render_pk(index1, '"name"') == expected + + @staticmethod + def test_no_comment(index1: Index) -> None: + expected = 'PRIMARY KEY ("name")' + assert render_pk(index1, '"name"') == expected + + +class TestRenderComponents: + @staticmethod + def test_comment(index1: Index) -> None: + index1.comment = "Test comment" + expected = '-- Test comment\nCREATE INDEX ON "products" ("name");' + assert render_index(index1) == expected + + @staticmethod + def test_unique(index1: Index) -> None: + index1.unique = True + expected = 'CREATE UNIQUE INDEX ON "products" ("name");' + assert render_index(index1) == expected + + @staticmethod + def test_name(index1: Index) -> None: + index1.name = "test" + expected = 'CREATE INDEX "test" ON "products" ("name");' + assert render_index(index1) == expected + + @staticmethod + def test_no_table(index1: Index) -> None: + index1.table = None + expected = 'CREATE INDEX ("name");' + assert render_index(index1) == expected + + @staticmethod + def test_type(index1: Index) -> None: + index1.type = "hash" + expected = 'CREATE INDEX ON "products" USING HASH ("name");' + assert render_index(index1) == expected + + +class TestRenderIndex: + @staticmethod + def test_render_index(index1: Index) -> None: + index1.comment = "Test comment" + index1.unique = True + index1.name = "test" + index1.type = "hash" + + expected = '-- Test comment\nCREATE UNIQUE INDEX "test" ON "products" USING HASH ("name");' + assert render_index(index1) == expected + + @staticmethod + def test_render_pk(index1: Index) -> None: + index1.pk = True + expected = 'PRIMARY KEY ("name")' + assert render_index(index1) == expected diff --git a/test/test_renderer/test_sql/test_default/test_note.py b/test/test_renderer/test_sql/test_default/test_note.py new file mode 100644 index 0000000..fc3e11c --- /dev/null +++ b/test/test_renderer/test_sql/test_default/test_note.py @@ -0,0 +1,45 @@ +from textwrap import dedent + +from pydbml.classes import Note, Table, Index +from pydbml.renderer.sql.default.note import prepare_text_for_sql, generate_comment_on, render_note + + +def test_prepare_text_for_sql() -> None: + text = dedent( + """\ + First line break is preserved + second line break \\ + is 'ignored' """ + ) + expected = 'First line break is preserved\nsecond line break is "ignored" ' + assert prepare_text_for_sql(Note(text)) == expected + + +def test_generate_comment_on(note1: Note) -> None: + expected = "COMMENT ON TABLE \"table1\" IS 'Simple note';" + + assert generate_comment_on(note1, "Table", "table1") == expected + + +class TestRenderNote: + @staticmethod + def test_table_note_with_text(note1: Note, table1: Table) -> None: + table1.note = note1 + expected = "COMMENT ON TABLE \"products\" IS 'Simple note';" + assert render_note(note1) == expected + + @staticmethod + def test_table_note_without_text(note1: Note, table1: Table) -> None: + table1.note = note1 + note1.text = "" + assert render_note(note1) == "" + + @staticmethod + def test_index_note(index1: Index, multiline_note: Note) -> None: + index1.note = multiline_note + expected = dedent( + """\ + -- This is a multiline note. + -- It has multiple lines.""" + ) + assert render_note(multiline_note) == expected diff --git a/test/test_renderer/test_sql/test_default/test_reference.py b/test/test_renderer/test_sql/test_default/test_reference.py new file mode 100644 index 0000000..14bfced --- /dev/null +++ b/test/test_renderer/test_sql/test_default/test_reference.py @@ -0,0 +1,153 @@ +from textwrap import dedent +from unittest.mock import patch + +import pytest + +from pydbml.classes import Table, Reference +from pydbml.exceptions import TableNotFoundError +from pydbml.renderer.sql.default.reference import ( + validate_for_sql, + col_names, + generate_inline_sql, + generate_not_inline_sql, + generate_many_to_many_sql, + render_reference, +) + + +def test_col_names(table1: Table) -> None: + assert col_names(table1.columns) == '"id", "name"' + + +class TestValidateForSQL: + @staticmethod + def test_ok(reference1: Reference) -> None: + validate_for_sql(reference1) + + @staticmethod + def test_faulty(reference1: Reference) -> None: + reference1.col2[0].table = None + with pytest.raises(TableNotFoundError): + validate_for_sql(reference1) + + +class TestGenerateInlineSQL: + @staticmethod + def test_simple(reference1: Reference) -> None: + expected = '{c}FOREIGN KEY ("product_id") REFERENCES "products" ("id")' + assert ( + generate_inline_sql( + reference1, source_col=reference1.col1, ref_col=reference1.col2 + ) + == expected + ) + + @staticmethod + def test_on_update_on_delete(reference1: Reference) -> None: + reference1.on_update = "cascade" + reference1.on_delete = "set null" + expected = '{c}FOREIGN KEY ("product_id") REFERENCES "products" ("id") ON UPDATE CASCADE ON DELETE SET NULL' + assert ( + generate_inline_sql( + reference1, source_col=reference1.col1, ref_col=reference1.col2 + ) + == expected + ) + + +class TestGenerateNotInlineSQL: + @staticmethod + def test_simple(reference1: Reference) -> None: + expected = 'ALTER TABLE "orders" ADD {c}FOREIGN KEY ("product_id") REFERENCES "products" ("id");' + assert ( + generate_not_inline_sql( + reference1, source_col=reference1.col1, ref_col=reference1.col2 + ) + == expected + ) + + @staticmethod + def test_on_update_on_delete(reference1: Reference) -> None: + reference1.on_update = "cascade" + reference1.on_delete = "set null" + expected = 'ALTER TABLE "orders" ADD {c}FOREIGN KEY ("product_id") REFERENCES "products" ("id") ON UPDATE CASCADE ON DELETE SET NULL;' + assert ( + generate_not_inline_sql( + reference1, source_col=reference1.col1, ref_col=reference1.col2 + ) + == expected + ) + + +def test_generate_many_to_many_sql(reference1: Reference) -> None: + reference1.type = "<>" + expected = dedent( + """\ + CREATE TABLE "orders_products" ( + "orders_product_id" integer NOT NULL, + "products_id" integer NOT NULL, + PRIMARY KEY ("orders_product_id", "products_id") + ); + + ALTER TABLE "orders_products" ADD FOREIGN KEY ("orders_product_id") REFERENCES "orders" ("product_id"); + + ALTER TABLE "orders_products" ADD FOREIGN KEY ("products_id") REFERENCES "products" ("id");""" + ) + assert generate_many_to_many_sql(reference1) == expected + + +class TestRenderReference: + @staticmethod + def test_many_to_many(reference1: Reference) -> None: + reference1.type = "<>" + with patch( + "pydbml.renderer.sql.default.reference.generate_many_to_many_sql" + ) as mock: + render_reference(reference1) + mock.assert_called_once_with(reference1) + + @staticmethod + def test_inline_to_one(reference1: Reference) -> None: + reference1.type = ">" + reference1.inline = True + with patch("pydbml.renderer.sql.default.reference.generate_inline_sql") as mock: + render_reference(reference1) + mock.assert_called_once_with( + model=reference1, source_col=reference1.col1, ref_col=reference1.col2 + ) + reference1.type = "-" + render_reference(reference1) + assert mock.call_count == 2 + + @staticmethod + def test_inline_to_many(reference1: Reference) -> None: + reference1.type = "<" + reference1.inline = True + with patch("pydbml.renderer.sql.default.reference.generate_inline_sql") as mock: + render_reference(reference1) + mock.assert_called_once_with( + model=reference1, source_col=reference1.col2, ref_col=reference1.col1 + ) + + @staticmethod + def test_not_inline_to_one(reference1: Reference) -> None: + reference1.type = ">" + reference1.inline = False + with patch("pydbml.renderer.sql.default.reference.generate_not_inline_sql") as mock: + render_reference(reference1) + mock.assert_called_once_with( + model=reference1, source_col=reference1.col1, ref_col=reference1.col2 + ) + reference1.type = "-" + render_reference(reference1) + assert mock.call_count == 2 + + @staticmethod + def test_not_inline_to_many(reference1: Reference) -> None: + reference1.type = "<" + reference1.inline = False + with patch("pydbml.renderer.sql.default.reference.generate_not_inline_sql") as mock: + render_reference(reference1) + mock.assert_called_once_with( + model=reference1, source_col=reference1.col2, ref_col=reference1.col1 + ) diff --git a/test/test_renderer/test_sql/test_default/test_renderer.py b/test/test_renderer/test_sql/test_default/test_renderer.py new file mode 100644 index 0000000..7b7606d --- /dev/null +++ b/test/test_renderer/test_sql/test_default/test_renderer.py @@ -0,0 +1,29 @@ +from unittest.mock import Mock, patch + +from pydbml.renderer.sql.default import DefaultSQLRenderer + + +def test_render() -> None: + model = Mock() + result = DefaultSQLRenderer.render(model) + assert model.check_attributes_for_sql.called + assert result == "" + + +def test_render_db() -> None: + db = Mock( + refs=(Mock(inline=False), Mock(inline=False), Mock(inline=True)), + tables=[Mock(), Mock(), Mock()], + enums=[Mock(), Mock()], + ) + + with patch( + "pydbml.renderer.sql.default.renderer.reorder_tables_for_sql", + Mock(return_value=db.tables), + ) as reorder_mock: + with patch.object( + DefaultSQLRenderer, "render", Mock(return_value="") + ) as render_mock: + result = DefaultSQLRenderer.render_db(db) + assert reorder_mock.called + assert render_mock.call_count == 7 diff --git a/test/test_renderer/test_sql/test_default/test_table.py b/test/test_renderer/test_sql/test_default/test_table.py new file mode 100644 index 0000000..1504ef0 --- /dev/null +++ b/test/test_renderer/test_sql/test_default/test_table.py @@ -0,0 +1,215 @@ +from typing import Tuple +from unittest.mock import Mock, patch + +import pytest + +import pydbml.renderer.sql.default.table +from pydbml import Database +from pydbml.classes import Table, Column, Reference, Note +from pydbml.exceptions import UnknownDatabaseError +from pydbml.renderer.sql.default.table import ( + get_references_for_sql, + get_inline_references_for_sql, + create_components, + render_column_notes, + create_body, +) + + +@pytest.fixture +def db1(): + return Database() + + +@pytest.fixture +def table1(db1: Database) -> Table: + t = Table( + name="products", + columns=[ + Column("id", "integer", pk=True), + Column("name", "varchar"), + ], + ) + db1.add(t) + return t + + +@pytest.fixture +def table2(db1: Database) -> Table: + t = Table( + name="names", + columns=[ + Column("id", "integer"), + Column("name_val", "varchar"), + ], + ) + db1.add(t) + return t + + +@pytest.fixture +def not_inline_refs( + db1: Database, table1: Table, table2: Table +) -> Tuple[Reference, Reference, Reference]: + r1 = Reference(">", table1[1], table2[1], inline=False) + r2 = Reference("-", table1[0], table2[0], inline=False) + r3 = Reference("<", table1[0], table2[1], inline=False) + db1.add(r1) + db1.add(r2) + db1.add(r3) + return r1, r2, r3 + + +@pytest.fixture +def inline_refs( + db1: Database, table1: Table, table2: Table +) -> Tuple[Reference, Reference, Reference]: + r1 = Reference(">", table1[1], table2[1], inline=True) + r2 = Reference("-", table1[0], table2[0], inline=True) + r3 = Reference("<", table1[0], table2[1], inline=True) + db1.add(r1) + db1.add(r2) + db1.add(r3) + return r1, r2, r3 + + +class TestGetReferencesForSQL: + @staticmethod + def test_get_references_for_sql_not_inline( + table1: Table, table2: Table, not_inline_refs + ) -> None: + r1, r2, r3 = not_inline_refs + assert get_references_for_sql(table1) == [r1, r2] + assert get_references_for_sql(table2) == [r3] + + @staticmethod + def test_get_references_for_sql_inline( + table1: Table, table2: Table, inline_refs + ) -> None: + r1, r2, r3 = inline_refs + assert get_references_for_sql(table1) == [r1, r2] + assert get_references_for_sql(table2) == [r3] + + @staticmethod + def test_db_not_set(table1: Table) -> None: + table1.database = None + with pytest.raises(UnknownDatabaseError): + get_references_for_sql(table1) + + +class TestGetInlineReferencesForSQL: + @staticmethod + def test_inline(table1: Table, table2: Table, inline_refs) -> None: + r1, r2, r3 = inline_refs + assert get_inline_references_for_sql(table1) == [r1, r2] + assert get_inline_references_for_sql(table2) == [r3] + + @staticmethod + def test_not_inline(table1: Table, table2: Table, not_inline_refs) -> None: + assert get_inline_references_for_sql(table1) == [] + assert get_inline_references_for_sql(table2) == [] + + @staticmethod + def test_abstract(table1: Table, table2: Table, inline_refs) -> None: + table1.abstract = table2.abstract = True + assert get_inline_references_for_sql(table1) == [] + assert get_inline_references_for_sql(table2) == [] + + +class TestCreateBody: + @staticmethod + def test_create_body() -> None: + table = Mock( + columns=[Mock(), Mock()], + indexes=[Mock(pk=True), Mock(pk=False)], + ) + with patch( + "pydbml.renderer.sql.default.table.get_inline_references_for_sql", + Mock(return_value=[Mock()]), + ) as get_inline_mock: + with patch( + "pydbml.renderer.sql.default.renderer.DefaultSQLRenderer.render", + Mock(return_value=""), + ) as render_mock: + create_body(table) + assert get_inline_mock.called + assert render_mock.call_count == 4 + + @staticmethod + def test_composite_pk(table1: Table) -> None: + table1.add_column(Column("id2", "integer", pk=True)) + expected = ( + ' "id" integer,\n' + ' "name" varchar,\n' + ' "id2" integer,\n' + ' PRIMARY KEY ("id", "id2")' + ) + assert create_body(table1) == expected + + +class TestCreateComponents: + @staticmethod + def test_simple(table1: Table) -> None: + with patch( + "pydbml.renderer.sql.default.table.create_body", Mock(return_value="body") + ) as create_body_mock: + expected = 'CREATE TABLE "products" (\nbody\n);' + assert create_components(table1) == expected + + @staticmethod + def test_comment(table1: Table) -> None: + table1.comment = "Simple comment" + with patch( + "pydbml.renderer.sql.default.table.create_body", Mock(return_value="body") + ) as create_body_mock: + expected = '-- Simple comment\n\nCREATE TABLE "products" (\nbody\n);' + assert create_components(table1) == expected + + @staticmethod + def test_indexes(table1: Table) -> None: + table1.indexes = [Mock(pk=False), Mock(pk=True)] + with patch( + "pydbml.renderer.sql.default.table.create_body", Mock(return_value="body") + ) as create_body_mock: + with patch( + "pydbml.renderer.sql.default.renderer.DefaultSQLRenderer.render", + Mock(return_value="index"), + ) as render_mock: + expected = 'CREATE TABLE "products" (\nbody\n);\n\nindex' + assert create_components(table1) == expected + + +class TestRenderColumnNotes: + @staticmethod + def test_notes(table1: Table) -> None: + table1.columns[0].note = Note("First column note") + table1.columns[1].note = Note("Second column note") + expected = ( + "\n" + "\n" + 'COMMENT ON COLUMN "products"."id" IS \'First column note\';\n' + "\n" + 'COMMENT ON COLUMN "products"."name" IS \'Second column note\';' + ) + assert render_column_notes(table1) == expected + + @staticmethod + def test_no_notes(table1: Table) -> None: + assert render_column_notes(table1) == "" + + +def test_render_table(table1: Table) -> None: + table1.note = Mock(sql="-- Simple note") + with patch( + "pydbml.renderer.sql.default.table.create_components", + Mock(return_value="components"), + ) as create_components_mock: + with patch( + "pydbml.renderer.sql.default.table.render_column_notes", + Mock(return_value="\n\ncolumn notes"), + ) as render_column_notes_mock: + assert pydbml.renderer.sql.default.table.render_table(table1) == ( + "components\n\n-- Simple note\n\ncolumn notes" + ) + assert create_components_mock.called + assert render_column_notes_mock.called diff --git a/test/test_renderer/test_sql/test_default/test_utils.py b/test/test_renderer/test_sql/test_default/test_utils.py new file mode 100644 index 0000000..82c3477 --- /dev/null +++ b/test/test_renderer/test_sql/test_default/test_utils.py @@ -0,0 +1,50 @@ +from unittest.mock import Mock + +from pydbml.classes import Enum +from pydbml.constants import ONE_TO_MANY, MANY_TO_ONE, MANY_TO_MANY +from pydbml.renderer.sql.default.utils import ( + get_full_name_for_sql, + reorder_tables_for_sql, +) + + +class TestGetFullNameForSQL: + @staticmethod + def test_public(enum1: Enum) -> None: + assert get_full_name_for_sql(enum1) == '"product status"' + + @staticmethod + def test_schema(enum1: Enum) -> None: + enum1.schema = "myschema" + assert get_full_name_for_sql(enum1) == '"myschema"."product status"' + + +def test_reorder_tables() -> None: + t1 = Mock(name="table1") # 1 ref + t2 = Mock(name="table2") # 2 refs + t3 = Mock(name="table3") + t4 = Mock(name="table4") # 1 ref + t5 = Mock(name="table5") + t6 = Mock(name="table6") # 3 refs + t7 = Mock(name="table7") + t8 = Mock(name="table8") + t9 = Mock(name="table9") + t10 = Mock(name="table10") + + refs = [ + Mock(type=ONE_TO_MANY, table1=t1, table2=t2, inline=True), + Mock(type=MANY_TO_ONE, table1=t4, table2=t3, inline=True), + Mock(type=ONE_TO_MANY, table1=t6, table2=t2, inline=True), + Mock(type=ONE_TO_MANY, table1=t7, table2=t6, inline=True), + Mock(type=MANY_TO_ONE, table1=t6, table2=t8, inline=True), + Mock(type=ONE_TO_MANY, table1=t9, table2=t6, inline=True), + Mock( + type=ONE_TO_MANY, table1=t1, table2=t2, inline=False + ), # ignored not inline + Mock(type=ONE_TO_MANY, table1=t10, table2=t1, inline=True), + Mock(type=MANY_TO_MANY, table1=t1, table2=t2, inline=True), # ignored m2m + ] + original = [t1, t2, t3, t4, t5, t6, t7, t8, t9, t10] + expected = [t6, t2, t1, t4, t3, t5, t7, t8, t9, t10] + result = reorder_tables_for_sql(original, refs) # type: ignore + assert expected == result diff --git a/test/test_table.py b/test/test_table.py deleted file mode 100644 index 02983a6..0000000 --- a/test/test_table.py +++ /dev/null @@ -1,182 +0,0 @@ -from unittest import TestCase - -from pyparsing import ParseException -from pyparsing import ParseSyntaxException -from pyparsing import ParserElement - -from pydbml.definitions.table import alias -from pydbml.definitions.table import header_color -from pydbml.definitions.table import table -from pydbml.definitions.table import table_body -from pydbml.definitions.table import table_settings - - -ParserElement.setDefaultWhitespaceChars(' \t\r') - - -class TestAlias(TestCase): - def test_ok(self) -> None: - val = 'as Alias' - alias.parseString(val, parseAll=True) - - def test_nok(self) -> None: - val = 'asalias' - with self.assertRaises(ParseSyntaxException): - alias.parseString(val, parseAll=True) - - -class TestHeaderColor(TestCase): - def test_oneline(self) -> None: - val = 'headercolor: #CCCCCC' - res = header_color.parseString(val, parseAll=True) - self.assertEqual(res['header_color'], '#CCCCCC') - - def test_multiline(self) -> None: - val = 'headercolor:\n\n#E02' - res = header_color.parseString(val, parseAll=True) - self.assertEqual(res['header_color'], '#E02') - - -class TestTableSettings(TestCase): - def test_one(self) -> None: - val = '[headercolor: #E024DF]' - res = table_settings.parseString(val, parseAll=True) - self.assertEqual(res[0]['header_color'], '#E024DF') - - def test_both(self) -> None: - val = '[note: "note content", headercolor: #E024DF]' - res = table_settings.parseString(val, parseAll=True) - self.assertEqual(res[0]['header_color'], '#E024DF') - self.assertIn('note', res[0]) - - -class TestTableBody(TestCase): - def test_one_column(self) -> None: - val = 'id integer [pk, increment]\n' - res = table_body.parseString(val, parseAll=True) - self.assertEqual(len(res['columns']), 1) - - def test_two_columns(self) -> None: - val = 'id integer [pk, increment]\nname string\n' - res = table_body.parseString(val, parseAll=True) - self.assertEqual(len(res['columns']), 2) - - def test_columns_indexes(self) -> None: - val = ''' -id integer -country varchar [NOT NULL, ref: > countries.country_name] -booking_date date unique pk -indexes { - (id, country) [pk] // composite primary key -}''' - res = table_body.parseString(val, parseAll=True) - self.assertEqual(len(res['columns']), 3) - self.assertEqual(len(res['indexes']), 1) - - def test_columns_indexes_note(self) -> None: - val = ''' -id integer -country varchar [NOT NULL, ref: > countries.country_name] -booking_date date unique pk -note: 'mynote' -indexes { - (id, country) [pk] // composite primary key -}''' - res = table_body.parseString(val, parseAll=True) - self.assertEqual(len(res['columns']), 3) - self.assertEqual(len(res['indexes']), 1) - self.assertIsNotNone(res['note']) - val2 = ''' -id integer -country varchar [NOT NULL, ref: > countries.country_name] -booking_date date unique pk -note { - 'mynote' -} -indexes { - (id, country) [pk] // composite primary key -}''' - res2 = table_body.parseString(val2, parseAll=True) - self.assertEqual(len(res2['columns']), 3) - self.assertEqual(len(res2['indexes']), 1) - self.assertIsNotNone(res2['note']) - - def test_no_columns(self) -> None: - val = ''' -note: 'mynote' -indexes { - (id, country) [pk] // composite primary key -}''' - with self.assertRaises(ParseException): - table_body.parseString(val, parseAll=True) - - def test_columns_after_indexes(self) -> None: - val = ''' -note: 'mynote' -indexes { - (id, country) [pk] // composite primary key -} -id integer''' - with self.assertRaises(ParseException): - table_body.parseString(val, parseAll=True) - - -class TestTable(TestCase): - def test_simple(self) -> None: - val = 'table ids {\nid integer\n}' - res = table.parseString(val, parseAll=True) - self.assertEqual(res[0].name, 'ids') - self.assertEqual(len(res[0].columns), 1) - - def test_with_alias(self) -> None: - val = 'table ids as ii {\nid integer\n}' - res = table.parseString(val, parseAll=True) - self.assertEqual(res[0].name, 'ids') - self.assertEqual(res[0].alias, 'ii') - self.assertEqual(len(res[0].columns), 1) - - def test_with_settings(self) -> None: - val = 'table ids as ii [headercolor: #ccc, note: "headernote"] {\nid integer\n}' - res = table.parseString(val, parseAll=True) - self.assertEqual(res[0].name, 'ids') - self.assertEqual(res[0].alias, 'ii') - self.assertEqual(res[0].header_color, '#ccc') - self.assertEqual(res[0].note.text, 'headernote') - self.assertEqual(len(res[0].columns), 1) - - def test_with_body_note(self) -> None: - val = ''' -table ids as ii [ - headercolor: #ccc, - note: "headernote"] -{ - id integer - note: "bodynote" -}''' - res = table.parseString(val, parseAll=True) - self.assertEqual(res[0].name, 'ids') - self.assertEqual(res[0].alias, 'ii') - self.assertEqual(res[0].header_color, '#ccc') - self.assertEqual(res[0].note.text, 'bodynote') - self.assertEqual(len(res[0].columns), 1) - - def test_with_indexes(self) -> None: - val = ''' -table ids as ii [ - headercolor: #ccc, - note: "headernote"] -{ - id integer - country varchar - note: "bodynote" - indexes { - (id, country) [pk] // composite primary key - } -}''' - res = table.parseString(val, parseAll=True) - self.assertEqual(res[0].name, 'ids') - self.assertEqual(res[0].alias, 'ii') - self.assertEqual(res[0].header_color, '#ccc') - self.assertEqual(res[0].note.text, 'bodynote') - self.assertEqual(len(res[0].columns), 2) - self.assertEqual(len(res[0].indexes), 1) diff --git a/test/test_table_group.py b/test/test_table_group.py deleted file mode 100644 index 7e81251..0000000 --- a/test/test_table_group.py +++ /dev/null @@ -1,28 +0,0 @@ -from unittest import TestCase - -from pyparsing import ParserElement - -from pydbml.definitions.table_group import table_group - - -ParserElement.setDefaultWhitespaceChars(' \t\r') - - -class TestProject(TestCase): - def test_empty(self) -> None: - val = 'TableGroup name {}' - res = table_group.parseString(val, parseAll=True) - self.assertEqual(res[0].name, 'name') - - def test_fields(self) -> None: - val = "TableGroup name {table1 table2}" - res = table_group.parseString(val, parseAll=True) - self.assertEqual(res[0].name, 'name') - self.assertEqual(res[0].items, ['table1', 'table2']) - - def test_comment(self) -> None: - val = "//comment before\nTableGroup name\n{\ntable1\ntable2\n}" - res = table_group.parseString(val, parseAll=True) - self.assertEqual(res[0].name, 'name') - self.assertEqual(res[0].items, ['table1', 'table2']) - self.assertEqual(res[0].comment, 'comment before') diff --git a/test/test_tools.py b/test/test_tools.py new file mode 100644 index 0000000..a452fb2 --- /dev/null +++ b/test/test_tools.py @@ -0,0 +1,131 @@ +from unittest import TestCase + +import pytest + +from pydbml.classes import Note +from pydbml.tools import remove_indentation, doublequote_string +from pydbml.renderer.sql.default.utils import comment_to_sql +from pydbml.tools import indent +from pydbml.renderer.dbml.default.utils import note_option_to_dbml, comment_to_dbml +from pydbml.tools import strip_empty_lines + + +class TestCommentToDBML(TestCase): + def test_comment(self) -> None: + oneline = "comment" + self.assertEqual(f"// {oneline}\n", comment_to_dbml(oneline)) + + expected = """// +// line1 +// line2 +// line3 +// +""" + source = "\nline1\nline2\nline3\n" + self.assertEqual(comment_to_dbml(source), expected) + + +class TestCommentToSQL(TestCase): + def test_comment(self) -> None: + oneline = "comment" + self.assertEqual(f"-- {oneline}\n", comment_to_sql(oneline)) + + expected = """-- +-- line1 +-- line2 +-- line3 +-- +""" + source = "\nline1\nline2\nline3\n" + self.assertEqual(comment_to_sql(source), expected) + + +class TestNoteOptionToDBML(TestCase): + def test_oneline(self) -> None: + note = Note("one line note") + self.assertEqual(f"note: 'one line note'", note_option_to_dbml(note)) + + def test_oneline_with_quote(self) -> None: + note = Note("one line'd note") + self.assertEqual(f"note: 'one line\\'d note'", note_option_to_dbml(note)) + + def test_multiline(self) -> None: + note = Note("line1\nline2\nline3") + expected = "note: '''line1\nline2\nline3'''" + self.assertEqual(expected, note_option_to_dbml(note)) + + def test_multiline_with_quotes(self) -> None: + note = Note("line1\n'''line2\nline3") + expected = "note: '''line1\n\\'''line2\nline3'''" + self.assertEqual(expected, note_option_to_dbml(note)) + + +class TestIndent(TestCase): + def test_empty(self) -> None: + self.assertEqual(indent(""), "") + + def test_nonempty(self) -> None: + oneline = "one line text" + self.assertEqual(indent(oneline), f" {oneline}") + source = "line1\nline2\nline3" + expected = " line1\n line2\n line3" + self.assertEqual(indent(source), expected) + expected2 = " line1\n line2\n line3" + self.assertEqual(indent(source, 2), expected2) + + +class TestStripEmptyLines(TestCase): + def test_empty(self) -> None: + source = "" + self.assertEqual(strip_empty_lines(source), source) + + def test_no_empty_lines(self) -> None: + source = "line1\n\n\nline2" + self.assertEqual(strip_empty_lines(source), source) + + def test_empty_lines(self) -> None: + stripped = " line1\n\n line2" + source = f"\n \n \n\t \t \n \n{stripped}\n\n\n \n \t \n\t \n \n" + self.assertEqual(strip_empty_lines(source), stripped) + + def test_one_empty_line(self) -> None: + stripped = " line1\n\n line2" + source = f"\n{stripped}" + self.assertEqual(strip_empty_lines(source), stripped) + source = f"{stripped}\n" + self.assertEqual(strip_empty_lines(source), stripped) + + def test_end(self) -> None: + stripped = " line1\n\n line2" + source = f"\n{stripped}\n " + self.assertEqual(strip_empty_lines(source), stripped) + + +class TestRemoveIndentation(TestCase): + def test_empty(self) -> None: + source = "" + self.assertEqual(remove_indentation(source), source) + + def test_not_empty(self) -> None: + source = " line1\n line2" + expected = "line1\n line2" + self.assertEqual(remove_indentation(source), expected) + + +class TestDoublequoteString: + @staticmethod + @pytest.mark.parametrize( + "source,expected", + [ + ("Test string", '"Test string"'), + ('String with "quotes"!', '"String with \\"quotes\\"!"'), + ('"Quoted string"', '"Quoted string"'), + ], + ) + def test_oneline(source: str, expected: str) -> None: + assert doublequote_string(source) == expected + + @staticmethod + def test_multiline() -> None: + with pytest.raises(ValueError): + doublequote_string('line1\nline2') diff --git a/test/utils.py b/test/utils.py new file mode 100644 index 0000000..5d1163d --- /dev/null +++ b/test/utils.py @@ -0,0 +1,12 @@ +from unittest.mock import Mock +from typing import Optional + + +DEFAULT_OPTIONS = { + 'reformat_notes': True, +} + +def mock_parser(options: Optional[dict] = None): + if options is None: + options = dict(DEFAULT_OPTIONS) + return Mock(options=options) diff --git a/test_schema.dbml b/test_schema.dbml index 2739c7d..3d0b481 100644 --- a/test_schema.dbml +++ b/test_schema.dbml @@ -17,7 +17,10 @@ Enum "product status" { Table "orders" [headercolor: #fff] { "id" int [pk, increment] - "user_id" int [unique, not null] + "user_id" int [ + unique, + not null + ] "status" orders_status "created_at" varchar } @@ -36,7 +39,6 @@ Table "products" { "status" "product status" "created_at" datetime [default: `now()`] - Indexes { (merchant_id, status) [name: "product_status"] id [type: hash, unique] @@ -55,9 +57,10 @@ Table "users" { Ref:"orders"."id" < "order_items"."order_id" -TableGroup g1 { +TableGroup g1 [note: 'test note', color: #FFF] { users merchants + note: 'test note 2' } TableGroup g2 { @@ -74,7 +77,7 @@ Table "merchants" { } -Ref:"products"."id" < "order_items"."product_id" +Ref:"products"."id" < "order_items"."product_id" [update: set default, delete: set null] Ref:"countries"."code" < "users"."country_code" @@ -89,3 +92,14 @@ Table "countries" { "name" varchar "continent_name" varchar } + +Note sticky_note1 { + 'One line note' +} + +Note sticky_note2 { + ''' + # Title + body + ''' +}