提交 | 用户 | 时间
e7c126 1 # encoding=utf8
H 2 """平台系统数据库迁移工具
3
4 Author: dhb52 (https://gitee.com/dhb52)
5
6 pip install simple-ddl-parser
7 """
8
9 import argparse
10 import pathlib
11 import re
12 import time
13 from abc import ABC, abstractmethod
14 from typing import Dict, Generator, Optional, Tuple, Union
15
16 from simple_ddl_parser import DDLParser
17
18 PREAMBLE = """/*
19  Iailab Database Transfer Tool
20
21  Source Server Type    : MySQL
22
23  Target Server Type    : {db_type}
24
25  Date: {date}
26 */
27
28 """
29
30
31 def load_and_clean(sql_file: str) -> str:
32     """加载源 SQL 文件,并清理内容方便下一步 ddl 解析
33
34     Args:
35         sql_file (str): sql文件路径
36
37     Returns:
38         str: 清理后的sql文件内容
39     """
40     REPLACE_PAIR_LIST = (
41         (" CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci ", " "),
42         (" KEY `", " INDEX `"),
43         ("UNIQUE INDEX", "UNIQUE KEY"),
44         ("b'0'", "'0'"),
45         ("b'1'", "'1'"),
46     )
47
48     content = open(sql_file).read()
49     for replace_pair in REPLACE_PAIR_LIST:
50         content = content.replace(*replace_pair)
51     content = re.sub(r"ENGINE.*COMMENT", "COMMENT", content)
52     content = re.sub(r"ENGINE.*;", ";", content)
53     return content
54
55
56 class Convertor(ABC):
57     def __init__(self, src: str, db_type) -> None:
58         self.src = src
59         self.db_type = db_type
60         self.content = load_and_clean(self.src)
61         self.table_script_list = re.findall(r"CREATE TABLE [^;]*;", self.content)
62
63     @abstractmethod
64     def translate_type(self, type: str, size: Optional[Union[int, Tuple[int]]]) -> str:
65         """字段类型转换
66
67         Args:
68             type (str): 字段类型
69             size (Optional[Union[int, Tuple[int]]]): 字段长度描述, 如varchar(255), decimal(10,2)
70
71         Returns:
72             str: 类型定义
73         """
74         pass
75
76     @abstractmethod
77     def gen_create(self, table_ddl: Dict) -> str:
78         """生成 create 脚本
79
80         Args:
81             table_ddl (Dict): 表DDL
82
83         Returns:
84             str:  生成脚本
85         """
86         pass
87
88     @abstractmethod
89     def gen_pk(self, table_name: str) -> str:
90         """生成主键定义
91
92         Args:
93             table_name (str): 表名
94
95         Returns:
96             str: 生成脚本
97         """
98         pass
99
100     @abstractmethod
101     def gen_index(self, ddl: Dict) -> str:
102         """生成索引定义
103
104         Args:
105             table_ddl (Dict): 表DDL
106
107         Returns:
108             str: 生成脚本
109         """
110         pass
111
112     @abstractmethod
113     def gen_comment(self, table_sql: str, table_name: str) -> str:
114         """生成字段/表注释
115
116         Args:
117             table_sql (str): 原始表SQL
118             table_name (str): 表名
119
120         Returns:
121             str: 生成脚本
122         """
123         pass
124
125     @abstractmethod
126     def gen_insert(self, table_name: str) -> str:
127         """生成 insert 语句块
128
129         Args:
130             table_name (str): 表名
131
132         Returns:
133             str: 生成脚本
134         """
135         pass
136
137     def gen_dual(self) -> str:
138         """生成虚拟 dual 表
139
140         Returns:
141             str: 生成脚本, 默认返回空脚本, 表示当前数据库无需手工创建
142         """
143         return ""
144
145     @staticmethod
146     def inserts(table_name: str, script_content: str) -> Generator:
147         PREFIX = f"INSERT INTO `{table_name}`"
148
149         # 收集 `table_name` 对应的 insert 语句
150         for line in script_content.split("\n"):
151             if line.startswith(PREFIX):
152                 head, tail = line.replace(PREFIX, "").split(" VALUES ", maxsplit=1)
153                 head = head.strip().replace("`", "").lower()
154                 tail = tail.strip().replace(r"\"", '"')
155                 # tail = tail.replace("b'0'", "'0'").replace("b'1'", "'1'")
156                 yield f"INSERT INTO {table_name.lower()} {head} VALUES {tail}"
157
158     @staticmethod
159     def index(ddl: Dict) -> Generator:
160         """生成索引定义
161
162         Args:
163             ddl (Dict): 表DDL
164
165         Yields:
166             Generator[str]: create index 语句
167         """
168
169         def generate_columns(columns):
170             keys = [
171                 f"{col['name'].lower()}{' ' + col['order'].lower() if col['order'] != 'ASC' else ''}"
172                 for col in columns[0]
173             ]
174             return ", ".join(keys)
175
176         for no, index in enumerate(ddl["index"], 1):
177             columns = generate_columns(index["columns"])
178             table_name = ddl["table_name"].lower()
179             yield f"CREATE INDEX idx_{table_name}_{no:02d} ON {table_name} ({columns})"
180
181     @staticmethod
182     def filed_comments(table_sql: str) -> Generator:
183         for line in table_sql.split("\n"):
184             match = re.match(r"^`([^`]+)`.* COMMENT '([^']+)'", line.strip())
185             if match:
186                 field = match.group(1)
187                 comment_string = match.group(2).replace("\\n", "\n")
188                 yield field, comment_string
189
190     def table_comment(self, table_sql: str) -> str:
191         match = re.search(r"COMMENT \= '([^']+)';", table_sql)
192         return match.group(1) if match else None
193
194     def print(self):
195         """打印转换后的sql脚本到终端"""
196         print(
197             PREAMBLE.format(
198                 db_type=self.db_type,
199                 date=time.strftime("%Y-%m-%d %H:%M:%S"),
200             )
201         )
202
203         dual = self.gen_dual()
204         if dual:
205             print(
206                 f"""-- ----------------------------
207 -- Table structure for dual
208 -- ----------------------------
209 {dual}
210
211 """
212             )
213
214         error_scripts = []
215         for table_sql in self.table_script_list:
216             ddl = DDLParser(table_sql.replace("`", "")).run()
217
218             # 如果parse失败, 需要跟进
219             if len(ddl) == 0:
220                 error_scripts.append(table_sql)
221                 continue
222
223             table_ddl = ddl[0]
224             table_name = table_ddl["table_name"]
225
226             # 忽略 quartz 的内容
227             if table_name.lower().startswith("qrtz"):
228                 continue
229
230             # 为每个表生成个5个基本部分
231             create = self.gen_create(table_ddl)
232             pk = self.gen_pk(table_name)
233             index = self.gen_index(table_ddl)
234             comment = self.gen_comment(table_sql, table_name)
235             inserts = self.gen_insert(table_name)
236
237             # 组合当前表的DDL脚本
238             script = f"""{create}
239
240 {pk}
241
242 {index}
243
244 {comment}
245
246 {inserts}
247 """
248
249             # 清理
250             script = re.sub("\n{3,}", "\n\n", script).strip() + "\n"
251
252             print(script)
253
254         # 将parse失败的脚本打印出来
255         if error_scripts:
256             for script in error_scripts:
257                 print(script)
258
259
260 class PostgreSQLConvertor(Convertor):
261     def __init__(self, src):
262         super().__init__(src, "PostgreSQL")
263
264     def translate_type(self, type: str, size: Optional[Union[int, Tuple[int]]]):
265         """类型转换"""
266
267         type = type.lower()
268
269         if type == "varchar":
270             return f"varchar({size})"
271         if type == "int":
272             return "int4"
273         if type == "bigint" or type == "bigint unsigned":
274             return "int8"
275         if type == "datetime":
276             return "timestamp"
277         if type == "bit":
278             return "bool"
279         if type in ("tinyint", "smallint"):
280             return "int2"
281         if type == "text":
282             return "text"
283         if type in ("blob", "mediumblob"):
284             return "bytea"
285         if type == "decimal":
286             return (
287                 f"numeric({','.join(str(s) for s in size)})" if len(size) else "numeric"
288             )
289
290     def gen_create(self, ddl: Dict) -> str:
291         """生成 create"""
292
293         def _generate_column(col):
294             name = col["name"].lower()
295             if name == "deleted":
296                 return "deleted int2 NOT NULL DEFAULT 0"
297
298             type = col["type"].lower()
299             full_type = self.translate_type(type, col["size"])
300             nullable = "NULL" if col["nullable"] else "NOT NULL"
301             default = f"DEFAULT {col['default']}" if col["default"] is not None else ""
302             return f"{name} {full_type} {nullable} {default}"
303
304         table_name = ddl["table_name"].lower()
305         columns = [f"{_generate_column(col).strip()}" for col in ddl["columns"]]
306         filed_def_list = ",\n  ".join(columns)
307         script = f"""-- ----------------------------
308 -- Table structure for {table_name}
309 -- ----------------------------
310 DROP TABLE IF EXISTS {table_name};
311 CREATE TABLE {table_name} (
312     {filed_def_list}
313 );"""
314
315         return script
316
317     def gen_index(self, ddl: Dict) -> str:
318         return "\n".join(f"{script};" for script in self.index(ddl))
319
320     def gen_comment(self, table_sql: str, table_name: str) -> str:
321         """生成字段及表的注释"""
322
323         script = ""
324         for field, comment_string in self.filed_comments(table_sql):
325             script += (
326                 f"COMMENT ON COLUMN {table_name}.{field} IS '{comment_string}';" + "\n"
327             )
328
329         table_comment = self.table_comment(table_sql)
330         if table_comment:
331             script += f"COMMENT ON TABLE {table_name} IS '{table_comment}';\n"
332
333         return script
334
335     def gen_pk(self, table_name) -> str:
336         """生成主键定义"""
337         return f"ALTER TABLE {table_name} ADD CONSTRAINT pk_{table_name} PRIMARY KEY (id);\n"
338
339     def gen_insert(self, table_name: str) -> str:
340         """生成 insert 语句,以及根据最后的 insert id+1 生成 Sequence"""
341
342         inserts = list(Convertor.inserts(table_name, self.content))
343         ## 生成 insert 脚本
344         script = ""
345         last_id = 0
346         if inserts:
347             inserts_lines = "\n".join(inserts)
348             script += f"""\n\n-- ----------------------------
349 -- Records of {table_name.lower()}
350 -- ----------------------------
351 -- @formatter:off
352 BEGIN;
353 {inserts_lines}
354 COMMIT;
355 -- @formatter:on"""
356             match = re.search(r"VALUES \((\d+),", inserts[-1])
357             if match:
358                 last_id = int(match.group(1))
359
360         # 生成 Sequence
361         script += (
362             "\n\n"
363             + f"""DROP SEQUENCE IF EXISTS {table_name}_seq;
364 CREATE SEQUENCE {table_name}_seq
365     START {last_id + 1};"""
366         )
367
368         return script
369
370     def gen_dual(self) -> str:
371         return """DROP TABLE IF EXISTS dual;
372 CREATE TABLE dual
373 (
374 );"""
375
376
377 class OracleConvertor(Convertor):
378     def __init__(self, src):
379         super().__init__(src, "Oracle")
380
381     def translate_type(self, type: str, size: Optional[Union[int, Tuple[int]]]):
382         """类型转换"""
383         type = type.lower()
384
385         if type == "varchar":
386             return f"varchar2({size if size < 4000 else 4000})"
387         if type == "int":
388             return "number"
389         if type == "bigint" or type == "bigint unsigned":
390             return "number"
391         if type == "datetime":
392             return "date"
393         if type == "bit":
394             return "number(1,0)"
395         if type in ("tinyint", "smallint"):
396             return "smallint"
397         if type == "text":
398             return "clob"
399         if type in ("blob", "mediumblob"):
400             return "blob"
401         if type == "decimal":
402             return (
403                 f"number({','.join(str(s) for s in size)})" if len(size) else "number"
404             )
405
406     def gen_create(self, ddl) -> str:
407         """生成 CREATE 语句"""
408
409         def generate_column(col):
410             name = col["name"].lower()
411             if name == "deleted":
412                 return "deleted number(1,0) DEFAULT 0 NOT NULL"
413
414             type = col["type"].lower()
415             full_type = self.translate_type(type, col["size"])
416             nullable = "NULL" if col["nullable"] else "NOT NULL"
417             default = f"DEFAULT {col['default']}" if col["default"] is not None else ""
418             # Oracle 中 size 不能作为字段名
419             field_name = '"size"' if name == "size" else name
420             # Oracle DEFAULT 定义在 NULLABLE 之前
421             return f"{field_name} {full_type} {default} {nullable}"
422
423         table_name = ddl["table_name"].lower()
424         columns = [f"{generate_column(col).strip()}" for col in ddl["columns"]]
425         field_def_list = ",\n    ".join(columns)
426         script = f"""-- ----------------------------
427 -- Table structure for {table_name}
428 -- ----------------------------
429 CREATE TABLE {table_name} (
430     {field_def_list}
431 );"""
432
433         # oracle INSERT '' 不能通过 NOT NULL 校验
434         script = script.replace("DEFAULT '' NOT NULL", "DEFAULT '' NULL")
435
436         return script
437
438     def gen_index(self, ddl: Dict) -> str:
439         return "\n".join(f"{script};" for script in self.index(ddl))
440
441     def gen_comment(self, table_sql: str, table_name: str) -> str:
442         script = ""
443         for field, comment_string in self.filed_comments(table_sql):
444             script += (
445                 f"COMMENT ON COLUMN {table_name}.{field} IS '{comment_string}';" + "\n"
446             )
447
448         table_comment = self.table_comment(table_sql)
449         if table_comment:
450             script += f"COMMENT ON TABLE {table_name} IS '{table_comment}';\n"
451
452         return script
453
454     def gen_pk(self, table_name: str) -> str:
455         """生成主键定义"""
456         return f"ALTER TABLE {table_name} ADD CONSTRAINT pk_{table_name} PRIMARY KEY (id);\n"
457
458     def gen_index(self, ddl: Dict) -> str:
459         return "\n".join(f"{script};" for script in self.index(ddl))
460
461     def gen_insert(self, table_name: str) -> str:
462         """拷贝 INSERT 语句"""
463         inserts = []
464         for insert_script in Convertor.inserts(table_name, self.content):
465             # 对日期数据添加 TO_DATE 转换
466             insert_script = re.sub(
467                 r"('\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}')",
468                 r"to_date(\g<1>, 'SYYYY-MM-DD HH24:MI:SS')",
469                 insert_script,
470             )
471             inserts.append(insert_script)
472
473         ## 生成 insert 脚本
474         script = ""
475         last_id = 0
476         if inserts:
477             inserts_lines = "\n".join(inserts)
478             script += f"""\n\n-- ----------------------------
479 -- Records of {table_name.lower()}
480 -- ----------------------------
481 -- @formatter:off
482 {inserts_lines}
483 COMMIT;
484 -- @formatter:on"""
485             match = re.search(r"VALUES \((\d+),", inserts[-1])
486             if match:
487                 last_id = int(match.group(1))
488
489         # 生成 Sequence
490         script += f"""
491
492 CREATE SEQUENCE {table_name}_seq
493     START WITH {last_id + 1};"""
494
495         return script
496
497
498 class SQLServerConvertor(Convertor):
499     """_summary_
500
501     Args:
502         Convertor (_type_): _description_
503     """
504
505     def __init__(self, src):
506         super().__init__(src, "Microsoft SQL Server")
507
508     def translate_type(self, type: str, size: Optional[Union[int, Tuple[int]]]):
509         """类型转换"""
510
511         type = type.lower()
512
513         if type == "varchar":
514             return f"nvarchar({size if size < 4000 else 4000})"
515         if type == "int":
516             return "int"
517         if type == "bigint" or type == "bigint unsigned":
518             return "bigint"
519         if type == "datetime":
520             return "datetime2"
521         if type == "bit":
522             return "varchar(1)"
523         if type in ("tinyint", "smallint"):
524             return "tinyint"
525         if type == "text":
526             return "nvarchar(max)"
527         if type in ("blob", "mediumblob"):
528             return "varbinary(max)"
529         if type == "decimal":
530             return (
531                 f"numeric({','.join(str(s) for s in size)})" if len(size) else "numeric"
532             )
533
534     def gen_create(self, ddl: Dict) -> str:
535         """生成 create"""
536
537         def _generate_column(col):
538             name = col["name"].lower()
539             if name == "id":
540                 return "id bigint NOT NULL PRIMARY KEY IDENTITY"
541             if name == "deleted":
542                 return "deleted bit DEFAULT 0 NOT NULL"
543
544             type = col["type"].lower()
545             full_type = self.translate_type(type, col["size"])
546             nullable = "NULL" if col["nullable"] else "NOT NULL"
547             default = f"DEFAULT {col['default']}" if col["default"] is not None else ""
548             return f"{name} {full_type} {default} {nullable}"
549
550         table_name = ddl["table_name"].lower()
551         columns = [f"{_generate_column(col).strip()}" for col in ddl["columns"]]
552         filed_def_list = ",\n    ".join(columns)
553         script = f"""-- ----------------------------
554 -- Table structure for {table_name}
555 -- ----------------------------
556 DROP TABLE IF EXISTS {table_name};
557 CREATE TABLE {table_name} (
558     {filed_def_list}
559 )
560 GO"""
561
562         return script
563
564     def gen_comment(self, table_sql: str, table_name: str) -> str:
565         """生成字段及表的注释"""
566
567         script = ""
568
569         for field, comment_string in self.filed_comments(table_sql):
570             script += f"""EXEC sp_addextendedproperty
571     'MS_Description', N'{comment_string}',
572     'SCHEMA', N'dbo',
573     'TABLE', N'{table_name}',
574     'COLUMN', N'{field}'
575 GO
576
577 """
578
579         table_comment = self.table_comment(table_sql)
580         if table_comment:
581             script += f"""EXEC sp_addextendedproperty
582     'MS_Description', N'{table_comment}',
583     'SCHEMA', N'dbo',
584     'TABLE', N'{table_name}'
585 GO
586
587 """
588         return script
589
590     def gen_pk(self, table_name: str) -> str:
591         """生成主键定义"""
592         return ""
593
594     def gen_index(self, ddl: Dict) -> str:
595         """生成 index"""
596         return "\n".join(f"{script}\nGO" for script in self.index(ddl))
597
598     def gen_insert(self, table_name: str) -> str:
599         """生成 insert 语句"""
600
601         # 收集 `table_name` 对应的 insert 语句
602         inserts = []
603         for insert_script in Convertor.inserts(table_name, self.content):
604             # SQLServer: 字符串前加N,hack,是否存在替换字符串内容的风险
605             insert_script = insert_script.replace(", '", ", N'").replace(
606                 "VALUES ('", "VALUES (N')"
607             )
608             # 删除 insert 的结尾分号
609             insert_script = re.sub(";$", r"\nGO", insert_script)
610             inserts.append(insert_script)
611
612         ## 生成 insert 脚本
613         script = ""
614         if inserts:
615             inserts_lines = "\n".join(inserts)
616             script += f"""\n\n-- ----------------------------
617 -- Records of {table_name.lower()}
618 -- ----------------------------
619 -- @formatter:off
620 BEGIN TRANSACTION
621 GO
622 SET IDENTITY_INSERT {table_name.lower()} ON
623 GO
624 {inserts_lines}
625 SET IDENTITY_INSERT {table_name.lower()} OFF
626 GO
627 COMMIT
628 GO
629 -- @formatter:on"""
630
631         return script
632
633     def gen_dual(self) -> str:
634         return """DROP TABLE IF EXISTS dual
635 GO
636
637 CREATE TABLE dual
638 (
639   id int NULL
640 )
641 GO
642
643 EXEC sp_addextendedproperty
644     'MS_Description', N'数据库连接的表',
645     'SCHEMA', N'dbo',
646     'TABLE', N'dual'
647 GO"""
648
649
650 class DM8Convertor(Convertor):
651     def __init__(self, src):
652         super().__init__(src, "DM8")
653
654     def translate_type(self, type: str, size: Optional[Union[int, Tuple[int]]]):
655         """类型转换"""
656         type = type.lower()
657
658         if type == "varchar":
659             return f"varchar({size})"
660         if type == "int":
661             return "int"
662         if type == "bigint" or type == "bigint unsigned":
663             return "bigint"
664         if type == "datetime":
665             return "datetime"
666         if type == "bit":
667             return "bit"
668         if type in ("tinyint", "smallint"):
669             return "smallint"
670         if type == "text":
671             return "text"
672         if type == "blob":
673             return "blob"
674         if type == "mediumblob":
675             return "varchar(10240)"
676         if type == "decimal":
677             return (
678                 f"decimal({','.join(str(s) for s in size)})" if len(size) else "decimal"
679             )
680
681     def gen_create(self, ddl) -> str:
682         """生成 CREATE 语句"""
683
684         def generate_column(col):
685             name = col["name"].lower()
686             if name == "id":
687                 return "id bigint NOT NULL PRIMARY KEY IDENTITY"
688
689             type = col["type"].lower()
690             full_type = self.translate_type(type, col["size"])
691             nullable = "NULL" if col["nullable"] else "NOT NULL"
692             default = f"DEFAULT {col['default']}" if col["default"] is not None else ""
693             return f"{name} {full_type} {default} {nullable}"
694
695         table_name = ddl["table_name"].lower()
696         columns = [f"{generate_column(col).strip()}" for col in ddl["columns"]]
697         field_def_list = ",\n    ".join(columns)
698         script = f"""-- ----------------------------
699 -- Table structure for {table_name}
700 -- ----------------------------
701 CREATE TABLE {table_name} (
702     {field_def_list}
703 );"""
704
705         # oracle INSERT '' 不能通过 NOT NULL 校验
706         script = script.replace("DEFAULT '' NOT NULL", "DEFAULT '' NULL")
707
708         return script
709
710     def gen_index(self, ddl: Dict) -> str:
711         return "\n".join(f"{script};" for script in self.index(ddl))
712
713     def gen_comment(self, table_sql: str, table_name: str) -> str:
714         script = ""
715         for field, comment_string in self.filed_comments(table_sql):
716             script += (
717                 f"COMMENT ON COLUMN {table_name}.{field} IS '{comment_string}';" + "\n"
718             )
719
720         table_comment = self.table_comment(table_sql)
721         if table_comment:
722             script += f"COMMENT ON TABLE {table_name} IS '{table_comment}';\n"
723
724         return script
725
726     def gen_pk(self, table_name: str) -> str:
727         """生成主键定义"""
728         return ""
729
730     def gen_index(self, ddl: Dict) -> str:
731         return "\n".join(f"{script};" for script in self.index(ddl))
732
733     def gen_insert(self, table_name: str) -> str:
734         """拷贝 INSERT 语句"""
735         inserts = list(Convertor.inserts(table_name, self.content))
736
737         ## 生成 insert 脚本
738         script = ""
739         if inserts:
740             inserts_lines = "\n".join(inserts)
741             script += f"""\n\n-- ----------------------------
742 -- Records of {table_name.lower()}
743 -- ----------------------------
744 -- @formatter:off
745 SET IDENTITY_INSERT {table_name.lower()} ON;
746 {inserts_lines}
747 COMMIT;
748 SET IDENTITY_INSERT {table_name.lower()} OFF;
749 -- @formatter:on"""
750
751         return script
752
753
754 def main():
755     parser = argparse.ArgumentParser(description="平台系统数据库转换工具")
756     parser.add_argument(
757         "type",
758         type=str,
759         help="目标数据库类型",
760         choices=["postgres", "oracle", "sqlserver", "dm8"],
761     )
762     args = parser.parse_args()
763
764     sql_file = pathlib.Path("../mysql/ruoyi-vue-pro.sql").resolve().as_posix()
765     convertor = None
766     if args.type == "postgres":
767         convertor = PostgreSQLConvertor(sql_file)
768     elif args.type == "oracle":
769         convertor = OracleConvertor(sql_file)
770     elif args.type == "sqlserver":
771         convertor = SQLServerConvertor(sql_file)
772     elif args.type == "dm8":
773         convertor = DM8Convertor(sql_file)
774     else:
775         raise NotImplementedError(f"不支持目标数据库类型: {args.type}")
776
777     convertor.print()
778
779
780 if __name__ == "__main__":
781     main()