diff --git a/sql/tools/README.md b/sql/tools/README.md index 5bc99fadb..887cc87a1 100644 --- a/sql/tools/README.md +++ b/sql/tools/README.md @@ -50,8 +50,16 @@ TODO 暂未支持 使用方式如下: -```Bash -python3 convertor.py +安装依赖库 + +```bash +pip install simple-ddl-parser ``` -然后,TODO \ No newline at end of file +执行如下命令打印生成 postgresql 的脚本内容,其他可选参数有:oracle, sqlserver + +```Bash +python3 convertor.py postgres +``` + +程序将sql脚本打印到终端,可以重定向到临时文件tmp.sql, 确认无误后可以利用IDEA(专业版)进行格式化。 diff --git a/sql/tools/convertor.py b/sql/tools/convertor.py index d5e75647f..9f9cd955c 100644 --- a/sql/tools/convertor.py +++ b/sql/tools/convertor.py @@ -6,11 +6,12 @@ Author: dhb52 (https://gitee.com/dhb52) pip install simple-ddl-parser """ +import argparse import pathlib import re import time from abc import ABC, abstractmethod -from typing import Dict, Tuple +from typing import Dict, Generator, Optional, Tuple, Union from simple_ddl_parser import DDLParser @@ -60,12 +61,12 @@ class Convertor(ABC): self.table_script_list = re.findall(r"CREATE TABLE [^;]*;", self.content) @abstractmethod - def translate_type(self, type: str, size: None | int | Tuple[int]) -> str: + def translate_type(self, type: str, size: Optional[Union[int, Tuple[int]]]) -> str: """字段类型转换 Args: type (str): 字段类型 - size (None | int | Tuple[int]): 字段长度描述, 如varchar(255), decimal(10,2) + size (Optional[Union[int, Tuple[int]]]): 字段长度描述, 如varchar(255), decimal(10,2) Returns: str: 类型定义 @@ -97,7 +98,7 @@ class Convertor(ABC): pass @abstractmethod - def gen_index(self, table_ddl: Dict) -> str: + def gen_index(self, ddl: Dict) -> str: """生成索引定义 Args: @@ -133,6 +134,55 @@ class Convertor(ABC): """ pass + @staticmethod + def inserts(table_name: str, script_content: str) -> Generator: + PREFIX = f"INSERT INTO `{table_name}`" + + # 收集 `table_name` 对应的 insert 语句 + for line in script_content.split("\n"): + if line.startswith(PREFIX): + head, tail = line.replace(PREFIX, "").split(" VALUES ", maxsplit=1) + head = head.strip().replace("`", "").lower() + tail = tail.strip().replace(r"\"", '"') + # tail = tail.replace("b'0'", "'0'").replace("b'1'", "'1'") + yield f"INSERT INTO {table_name.lower()} {head} VALUES {tail}" + + @staticmethod + def index(ddl: Dict) -> Generator: + """生成索引定义 + + Args: + ddl (Dict): 表DDL + + Yields: + Generator[str]: create index 语句 + """ + + def generate_columns(columns): + keys = [ + f"{col['name'].lower()}{' ' + col['order'].lower() if col['order'] != 'ASC' else ''}" + for col in columns[0] + ] + return ", ".join(keys) + + for no, index in enumerate(ddl["index"], 1): + columns = generate_columns(index["columns"]) + table_name = ddl["table_name"].lower() + yield f"CREATE INDEX idx_{table_name}_{no:02d} ON {table_name} ({columns})" + + @staticmethod + def filed_comments(table_sql: str) -> Generator: + for line in table_sql.split("\n"): + match = re.match(r"^`([^`]+)`.* COMMENT '([^']+)'", line.strip()) + if match: + field = match.group(1) + comment_string = match.group(2).replace("\\n", "\n") + yield field, comment_string + + def table_comment(self, table_sql: str) -> str: + match = re.search(r"COMMENT \= '([^']+)';", table_sql) + return match.group(1) if match else None + def print(self): """打印转换后的sql脚本到终端""" print( @@ -192,7 +242,7 @@ class PostgreSQLConvertor(Convertor): def __init__(self, src): super().__init__(src, "PostgreSQL") - def translate_type(self, type, size): + def translate_type(self, type: str, size: Optional[Union[int, Tuple[int]]]): """类型转换""" type = type.lower() @@ -234,27 +284,30 @@ class PostgreSQLConvertor(Convertor): table_name = ddl["table_name"].lower() columns = [f"{_generate_column(col).strip()}" for col in ddl["columns"]] + filed_def_list = ",\n ".join(columns) script = f"""-- ---------------------------- -- Table structure for {table_name} -- ---------------------------- DROP TABLE IF EXISTS {table_name}; CREATE TABLE {table_name} ( - {',\n '.join(columns)} + {filed_def_list} );""" return script - def gen_comment(self, table_sql, table_name) -> str: + def gen_index(self, ddl: Dict) -> str: + return "\n".join(f"{script};" for script in self.index(ddl)) + + def gen_comment(self, table_sql: str, table_name: str) -> str: """生成字段及表的注释""" script = "" - for line in table_sql.split("\n"): - match = re.match(r"^`([^`]+)`.* COMMENT '([^']+)'", line.strip()) - if match: - script += f"COMMENT ON COLUMN {table_name}.{match.group(1)} IS '{match.group(2).replace('\\n', '\n')}';\n" + for field, comment_string in self.filed_comments(table_sql): + script += ( + f"COMMENT ON COLUMN {table_name}.{field} IS '{comment_string}';" + "\n" + ) - match = re.search(r"COMMENT \= '([^']+)';", table_sql) - table_comment = match.group(1) if match else None + table_comment = self.table_comment(table_sql) if table_comment: script += f"COMMENT ON TABLE {table_name} IS '{table_comment}';\n" @@ -264,53 +317,21 @@ CREATE TABLE {table_name} ( """生成主键定义""" return f"ALTER TABLE {table_name} ADD CONSTRAINT pk_{table_name} PRIMARY KEY (id);\n" - def gen_index(self, ddl) -> str: - """生成 index""" - - def generate_columns(columns): - keys = [ - f"{col['name'].lower()}{" " + col['order'].lower() if col['order'] != 'ASC' else ''}" - for col in columns[0] - ] - return ", ".join(keys) - - script = "" - for no, index in enumerate(ddl["index"], 1): - columns = generate_columns(index["columns"]) - table_name = ddl["table_name"].lower() - script += ( - f"CREATE INDEX idx_{table_name}_{no:02d} ON {table_name} ({columns});\n" - ) - - return script - - def gen_insert(self, table_name) -> str: + def gen_insert(self, table_name: str) -> str: """生成 insert 语句,以及根据最后的 insert id+1 生成 Sequence""" - PREFIX = f"INSERT INTO `{table_name}`" - - # 收集 `table_name` 对应的 insert 语句 - inserts = [] - for line in self.content.split("\n"): - if line.startswith(PREFIX): - head, tail = line.replace(PREFIX, "").split(" VALUES ", maxsplit=1) - head = head.strip().replace("`", "").lower() - tail = tail.strip().replace(r"\"", '"') - script = f"INSERT INTO {table_name.lower()} {head} VALUES {tail}" - # bit(1)数据转换 - script = script.replace("b'0'", "'0'").replace("b'1'", "'1'") - inserts.append(script) - + inserts = list(Convertor.inserts(table_name, self.content)) ## 生成 insert 脚本 script = "" last_id = 0 if inserts: + inserts_lines = "\n".join(inserts) script += f"""\n\n-- ---------------------------- -- Records of {table_name.lower()} -- ---------------------------- -- @formatter:off BEGIN; -{'\n'.join(inserts)} +{inserts_lines} COMMIT; -- @formatter:on""" match = re.search(r"VALUES \((\d+),", inserts[-1]) @@ -332,7 +353,7 @@ class OracleConvertor(Convertor): def __init__(self, src): super().__init__(src, "Oracle") - def translate_type(self, type, size: None | int | Tuple[int]): + def translate_type(self, type: str, size: Optional[Union[int, Tuple[int]]]): """类型转换""" type = type.lower() @@ -369,15 +390,19 @@ class OracleConvertor(Convertor): full_type = self.translate_type(type, col["size"]) nullable = "NULL" if col["nullable"] else "NOT NULL" default = f"DEFAULT {col['default']}" if col["default"] is not None else "" - return f"{'\"size\"' if name == "size" else name } {full_type} {default} {nullable}" + # Oracle 中 size 不能作为字段名 + field_name = '"size"' if name == "size" else name + # Oracle DEFAULT 定义在 NULLABLE 之前 + return f"{field_name} {full_type} {default} {nullable}" table_name = ddl["table_name"].lower() columns = [f"{generate_column(col).strip()}" for col in ddl["columns"]] + field_def_list = ",\n ".join(columns) script = f"""-- ---------------------------- -- Table structure for {table_name} -- ---------------------------- -CREATE TABLE {ddl['table_name'].lower()} ( - {',\n '.join(columns)} +CREATE TABLE {table_name} ( + {field_def_list} );""" # oracle INSERT '' 不能通过 NOT NULL 校验 @@ -385,72 +410,51 @@ CREATE TABLE {ddl['table_name'].lower()} ( return script - def gen_comment(self, table_sql, table_name) -> str: - script = "" - for line in table_sql.split("\n"): - match = re.search(r"`([^`]+)`.* COMMENT '([^']+)'", line) - if match: - script += f"COMMENT ON COLUMN {table_name}.{match.group(1)} IS '{match.group(2).replace('\\n', '\n')}';\n" + def gen_index(self, ddl: Dict) -> str: + return "\n".join(f"{script};" for script in self.index(ddl)) - match = re.search(r"COMMENT \= '([^']+)';", table_sql) - table_comment = match.group(1) if match else None + def gen_comment(self, table_sql: str, table_name: str) -> str: + script = "" + for field, comment_string in self.filed_comments(table_sql): + script += ( + f"COMMENT ON COLUMN {table_name}.{field} IS '{comment_string}';" + "\n" + ) + + table_comment = self.table_comment(table_sql) if table_comment: - script += f"COMMENT ON TABLE {table_name} IS '{table_comment}';" + script += f"COMMENT ON TABLE {table_name} IS '{table_comment}';\n" return script - def gen_pk(self, table_name) -> str: + def gen_pk(self, table_name: str) -> str: """生成主键定义""" return f"ALTER TABLE {table_name} ADD CONSTRAINT pk_{table_name} PRIMARY KEY (id);\n" - def gen_index(self, table_ddl) -> str: - """生成 INDEX 定义""" + def gen_index(self, ddl: Dict) -> str: + return "\n".join(f"{script};" for script in self.index(ddl)) - def generate_columns(columns): - keys = [ - f"{col['name'].lower()}{" " + col['order'].lower() if col['order'] != 'ASC' else ''}" - for col in columns[0] - ] - return ", ".join(keys) - - script = "" - for no, index in enumerate(table_ddl["index"], 1): - columns = generate_columns(index["columns"]) - table_name = table_ddl["table_name"].lower() - script += ( - f"CREATE INDEX idx_{table_name}_{no:02d} ON {table_name} ({columns});\n" - ) - return script - - def gen_insert(self, table_name) -> str: + def gen_insert(self, table_name: str) -> str: """拷贝 INSERT 语句""" - PREFIX = f"INSERT INTO `{table_name}`" inserts = [] - for line in self.content.split("\n"): - if line.startswith(PREFIX): - head, tail = line.replace(PREFIX, "").split(" VALUES ", maxsplit=1) - head = head.strip().replace("`", "").lower() - tail = tail.strip().replace(r"\"", '"') - script = f"INSERT INTO {table_name.lower()} {head} VALUES {tail}" - # bit(1)数据转换 - script = script.replace("b'0'", "'0'").replace("b'1'", "'1'") - # 对日期数据添加 TO_DATE 转换 - script = re.sub( - r"('\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}')", - r"to_date(\g<1>, 'SYYYY-MM-DD HH24:MI:SS')", - script, - ) - inserts.append(script) + for insert_script in Convertor.inserts(table_name, self.content): + # 对日期数据添加 TO_DATE 转换 + insert_script = re.sub( + r"('\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}')", + r"to_date(\g<1>, 'SYYYY-MM-DD HH24:MI:SS')", + insert_script, + ) + inserts.append(insert_script) ## 生成 insert 脚本 script = "" last_id = 0 if inserts: + inserts_lines = "\n".join(inserts) script += f"""\n\n-- ---------------------------- -- Records of {table_name.lower()} -- ---------------------------- -- @formatter:off -{'\n'.join(inserts)} +{inserts_lines} COMMIT; -- @formatter:on""" match = re.search(r"VALUES \((\d+),", inserts[-1]) @@ -476,7 +480,7 @@ class SQLServerConvertor(Convertor): def __init__(self, src): super().__init__(src, "Microsoft SQL Server") - def translate_type(self, type, size): + def translate_type(self, type: str, size: Optional[Union[int, Tuple[int]]]): """类型转换""" type = type.lower() @@ -507,7 +511,7 @@ class SQLServerConvertor(Convertor): def _generate_column(col): name = col["name"].lower() - if name == 'id': + if name == "id": return "id bigint NOT NULL PRIMARY KEY IDENTITY" if name == "deleted": return "deleted bit DEFAULT 0 NOT NULL" @@ -520,35 +524,34 @@ class SQLServerConvertor(Convertor): table_name = ddl["table_name"].lower() columns = [f"{_generate_column(col).strip()}" for col in ddl["columns"]] + filed_def_list = ",\n ".join(columns) script = f"""-- ---------------------------- -- Table structure for {table_name} -- ---------------------------- DROP TABLE IF EXISTS {table_name}; CREATE TABLE {table_name} ( - {',\n '.join(columns)} + {filed_def_list} ) GO""" return script - def gen_comment(self, table_sql, table_name) -> str: + def gen_comment(self, table_sql: str, table_name: str) -> str: """生成字段及表的注释""" script = "" - for line in table_sql.split("\n"): - match = re.match(r"^`([^`]+)`.* COMMENT '([^']+)'", line.strip()) - if match: - script += f"""EXEC sp_addextendedproperty - 'MS_Description', N'{match.group(2).replace('\\n', '\n')}', + + for field, comment_string in self.filed_comments(table_sql): + script += f"""EXEC sp_addextendedproperty + 'MS_Description', N'{comment_string}', 'SCHEMA', N'dbo', 'TABLE', N'{table_name}', - 'COLUMN', N'{match.group(1)}' + 'COLUMN', N'{field}' GO """ - match = re.search(r"COMMENT \= '([^']+)';", table_sql) - table_comment = match.group(1) if match else None + table_comment = self.table_comment(table_sql) if table_comment: script += f"""EXEC sp_addextendedproperty 'MS_Description', N'{table_comment}', @@ -557,55 +560,34 @@ GO GO """ - return script - def gen_pk(self, table_name) -> str: + def gen_pk(self, table_name: str) -> str: """生成主键定义""" return "" - def gen_index(self, ddl) -> str: + def gen_index(self, ddl: Dict) -> str: """生成 index""" + return "\n".join(f"{script}\nGO" for script in self.index(ddl)) - def generate_columns(columns): - keys = [ - f"{col['name'].lower()}{" " + col['order'].lower() if col['order'] != 'ASC' else ''}" - for col in columns[0] - ] - return ", ".join(keys) - - script = "" - for no, index in enumerate(ddl["index"], 1): - columns = generate_columns(index["columns"]) - table_name = ddl["table_name"].lower() - script += f"CREATE INDEX idx_{table_name}_{no:02d} ON {table_name} ({columns})\nGO\n" - - return script - - def gen_insert(self, table_name) -> str: + def gen_insert(self, table_name: str) -> str: """生成 insert 语句,以及根据最后的 insert id+1 生成 Sequence""" - PREFIX = f"INSERT INTO `{table_name}`" - # 收集 `table_name` 对应的 insert 语句 inserts = [] - for line in self.content.split("\n"): - if line.startswith(PREFIX): - head, tail = line.replace(PREFIX, "").split(" VALUES ", maxsplit=1) - head = head.strip().replace("`", "").lower() - tail = tail.strip().replace(r"\"", '"') - # SQLServer: 字符串前加N,hack,是否存在替换字符串内容的风险 - tail = tail.replace(", '", ", N'").replace("VALUES ('", "VALUES (N')") - script = f"INSERT INTO {table_name.lower()} {head} VALUES {tail}" - # bit(1)数据转换 - script = script.replace("b'0'", "'0'").replace("b'1'", "'1'") - # 删除 insert 的结尾分号 - script = re.sub(";$", r"\nGO", script) - inserts.append(script) + for insert_script in Convertor.inserts(table_name, self.content): + # SQLServer: 字符串前加N,hack,是否存在替换字符串内容的风险 + insert_script = insert_script.replace(", '", ", N'").replace( + "VALUES ('", "VALUES (N')" + ) + # 删除 insert 的结尾分号 + insert_script = re.sub(";$", r"\nGO", insert_script) + inserts.append(insert_script) ## 生成 insert 脚本 script = "" if inserts: + inserts_lines = "\n".join(inserts) script += f"""\n\n-- ---------------------------- -- Records of {table_name.lower()} -- ---------------------------- @@ -614,7 +596,7 @@ BEGIN TRANSACTION GO SET IDENTITY_INSERT {table_name.lower()} ON GO -{'\n'.join(inserts)} +{inserts_lines} SET IDENTITY_INSERT {table_name.lower()} OFF GO COMMIT @@ -625,10 +607,26 @@ GO def main(): - sql_file = pathlib.Path('../mysql/ruoyi-vue-pro.sql').resolve().as_posix() - # convertor = PostgreSQLConvertor(sql_file) - # convertor = OracleConvertor(sql_file) - convertor = SQLServerConvertor(sql_file) + parser = argparse.ArgumentParser(description="芋道系统数据库转换工具") + parser.add_argument( + "type", + type=str, + help="目标数据库类型", + choices=["postgres", "oracle", "sqlserver"], + ) + args = parser.parse_args() + + sql_file = pathlib.Path("../mysql/ruoyi-vue-pro.sql").resolve().as_posix() + convertor = None + if args.type == "postgres": + convertor = PostgreSQLConvertor(sql_file) + elif args.type == "oracle": + convertor = OracleConvertor(sql_file) + elif args.type == "sqlserver": + convertor = SQLServerConvertor(sql_file) + else: + raise NotImplementedError(f"不支持目标数据库类型: {args.type}") + convertor.print()