提交 | 用户 | 时间
|
e7c126
|
1 |
# encoding=utf8 |
H |
2 |
"""平台系统数据库迁移工具 |
|
3 |
|
|
4 |
Author: dhb52 (https://gitee.com/dhb52) |
|
5 |
|
|
6 |
pip install simple-ddl-parser |
|
7 |
""" |
|
8 |
|
|
9 |
import argparse |
|
10 |
import pathlib |
|
11 |
import re |
|
12 |
import time |
|
13 |
from abc import ABC, abstractmethod |
|
14 |
from typing import Dict, Generator, Optional, Tuple, Union |
|
15 |
|
|
16 |
from simple_ddl_parser import DDLParser |
|
17 |
|
|
18 |
# Header printed at the top of every generated script. Convertor.print fills
# the placeholders via str.format: {db_type} with the target database name and
# {date} with the generation timestamp ("%Y-%m-%d %H:%M:%S").
PREAMBLE = """/*
Iailab Database Transfer Tool

Source Server Type : MySQL

Target Server Type : {db_type}

Date: {date}
*/

"""
|
29 |
|
|
30 |
|
|
31 |
def load_and_clean(sql_file: str, encoding: str = "utf-8") -> str:
    """Load the source SQL file and normalise it for DDL parsing.

    The MySQL dump is rewritten into a form the DDL parser handles:
    charset/collation clauses are dropped, plain ``KEY `x``` becomes
    ``INDEX `x``` (while ``UNIQUE KEY`` is kept), bit literals ``b'0'``/``b'1'``
    become plain strings, and table options starting at ``ENGINE`` are removed
    (a trailing table ``COMMENT`` is preserved).

    Args:
        sql_file (str): path of the SQL file.
        encoding (str): text encoding of the file. Defaults to ``"utf-8"``,
            matching the utf8mb4 MySQL dump; previously the locale default
            was used, which fails on non-UTF-8 systems.

    Returns:
        str: the cleaned SQL content.
    """
    REPLACE_PAIR_LIST = (
        (" CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci ", " "),
        (" KEY `", " INDEX `"),
        # The previous pair also turns "UNIQUE KEY `" into "UNIQUE INDEX `";
        # restore the UNIQUE KEY spelling.
        ("UNIQUE INDEX", "UNIQUE KEY"),
        ("b'0'", "'0'"),
        ("b'1'", "'1'"),
    )

    # Context manager so the file handle is always closed.
    with open(sql_file, encoding=encoding) as f:
        content = f.read()
    for replace_pair in REPLACE_PAIR_LIST:
        content = content.replace(*replace_pair)
    # '.' does not match newlines, so both substitutions work line by line.
    content = re.sub(r"ENGINE.*COMMENT", "COMMENT", content)
    content = re.sub(r"ENGINE.*;", ";", content)
    return content
|
54 |
|
|
55 |
|
|
56 |
class Convertor(ABC):
    """Base class for MySQL -> target-database script converters.

    Subclasses implement the per-database hooks (type mapping, CREATE,
    primary key, index, comment and INSERT generation). :meth:`print` drives
    those hooks for every ``CREATE TABLE`` statement in the source file.
    """

    def __init__(self, src: str, db_type) -> None:
        self.src = src
        self.db_type = db_type
        # Cleaned source SQL; also scanned later for INSERT statements.
        self.content = load_and_clean(self.src)
        # Every "CREATE TABLE ...;" statement (assumes no ';' inside one).
        self.table_script_list = re.findall(r"CREATE TABLE [^;]*;", self.content)

    @abstractmethod
    def translate_type(
        self, type: str, size: Optional[Union[int, Tuple[int, ...]]]
    ) -> str:
        """Translate a MySQL column type to the target database type.

        Args:
            type (str): MySQL column type.
            size (Optional[Union[int, Tuple[int, ...]]]): length/precision as
                reported by the DDL parser, e.g. for varchar(255) or
                decimal(10,2).

        Returns:
            str: the target column type definition.
        """
        pass

    @abstractmethod
    def gen_create(self, table_ddl: Dict) -> str:
        """Generate the CREATE TABLE script.

        Args:
            table_ddl (Dict): parsed table DDL.

        Returns:
            str: generated script.
        """
        pass

    @abstractmethod
    def gen_pk(self, table_name: str) -> str:
        """Generate the primary key definition.

        Args:
            table_name (str): table name.

        Returns:
            str: generated script.
        """
        pass

    @abstractmethod
    def gen_index(self, ddl: Dict) -> str:
        """Generate the index definitions.

        Args:
            ddl (Dict): parsed table DDL.

        Returns:
            str: generated script.
        """
        pass

    @abstractmethod
    def gen_comment(self, table_sql: str, table_name: str) -> str:
        """Generate column and table comments.

        Args:
            table_sql (str): original CREATE TABLE SQL.
            table_name (str): table name.

        Returns:
            str: generated script.
        """
        pass

    @abstractmethod
    def gen_insert(self, table_name: str) -> str:
        """Generate the INSERT statement block.

        Args:
            table_name (str): table name.

        Returns:
            str: generated script.
        """
        pass

    def gen_dual(self) -> str:
        """Generate the virtual ``dual`` table.

        Returns:
            str: generated script. The default empty string means the target
            database does not need ``dual`` created manually.
        """
        return ""

    @staticmethod
    def inserts(table_name: str, script_content: str) -> Generator:
        """Yield the single-line INSERT statements for ``table_name``.

        A line matches when it starts with ``INSERT INTO `<table_name>```.
        Backticks are stripped from the column list, column names and the
        table name are lower-cased, and escaped double quotes in the values
        are unescaped.

        Args:
            table_name (str): table name as written in the MySQL dump.
            script_content (str): cleaned SQL content to scan.

        Yields:
            str: ``INSERT INTO <table> <columns> VALUES <values>`` statements.

        Raises:
            ValueError: if a matching line has no " VALUES " separator.
        """
        prefix = f"INSERT INTO `{table_name}`"

        # Collect the INSERT statements belonging to `table_name`.
        for line in script_content.split("\n"):
            if not line.startswith(prefix):
                continue
            # Remove only the leading prefix: str.replace would also delete
            # any occurrence of the prefix text inside the inserted values.
            head, tail = line[len(prefix):].split(" VALUES ", maxsplit=1)
            head = head.strip().replace("`", "").lower()
            tail = tail.strip().replace(r"\"", '"')
            yield f"INSERT INTO {table_name.lower()} {head} VALUES {tail}"

    @staticmethod
    def index(ddl: Dict) -> Generator:
        """Generate index definitions.

        Indexes are named ``idx_<table>_<NN>`` by position. ASC columns are
        emitted without an order keyword. Note that uniqueness is not carried
        over.

        Args:
            ddl (Dict): parsed table DDL.

        Yields:
            Generator[str]: CREATE INDEX statements (without terminator).
        """

        def generate_columns(columns):
            # NOTE(review): assumes the parser reports an index's columns as
            # a list whose first element is the list of column dicts.
            keys = [
                f"{col['name'].lower()}{' ' + col['order'].lower() if col['order'] != 'ASC' else ''}"
                for col in columns[0]
            ]
            return ", ".join(keys)

        for no, index in enumerate(ddl["index"], 1):
            columns = generate_columns(index["columns"])
            table_name = ddl["table_name"].lower()
            yield f"CREATE INDEX idx_{table_name}_{no:02d} ON {table_name} ({columns})"

    @staticmethod
    def filed_comments(table_sql: str) -> Generator:
        """Yield ``(field, comment)`` pairs from column lines of ``table_sql``.

        Literal ``\\n`` sequences in a comment are turned into real newlines.
        """
        for line in table_sql.split("\n"):
            match = re.match(r"^`([^`]+)`.* COMMENT '([^']+)'", line.strip())
            if match:
                field = match.group(1)
                comment_string = match.group(2).replace("\\n", "\n")
                yield field, comment_string

    def table_comment(self, table_sql: str) -> Optional[str]:
        """Return the table-level ``COMMENT = '...'`` text, or ``None``."""
        match = re.search(r"COMMENT \= '([^']+)';", table_sql)
        return match.group(1) if match else None

    def print(self):
        """Print the converted SQL script to stdout."""
        print(
            PREAMBLE.format(
                db_type=self.db_type,
                date=time.strftime("%Y-%m-%d %H:%M:%S"),
            )
        )

        dual = self.gen_dual()
        if dual:
            print(
                f"""-- ----------------------------
-- Table structure for dual
-- ----------------------------
{dual}

"""
            )

        error_scripts = []
        for table_sql in self.table_script_list:
            ddl = DDLParser(table_sql.replace("`", "")).run()

            # Parsing failed: keep the raw statement so it can be followed up.
            if len(ddl) == 0:
                error_scripts.append(table_sql)
                continue

            table_ddl = ddl[0]
            table_name = table_ddl["table_name"]

            # Skip the Quartz (qrtz_*) tables.
            if table_name.lower().startswith("qrtz"):
                continue

            # Generate the five basic parts of each table.
            create = self.gen_create(table_ddl)
            pk = self.gen_pk(table_name)
            index = self.gen_index(table_ddl)
            comment = self.gen_comment(table_sql, table_name)
            inserts = self.gen_insert(table_name)

            # Assemble the DDL script for the current table.
            script = f"""{create}

{pk}

{index}

{comment}

{inserts}
"""

            # Collapse runs of blank lines left by empty parts.
            script = re.sub("\n{3,}", "\n\n", script).strip() + "\n"

            print(script)

        # Print the statements that failed to parse.
        if error_scripts:
            for script in error_scripts:
                print(script)
|
258 |
|
|
259 |
|
|
260 |
class PostgreSQLConvertor(Convertor):
    """Convert the MySQL dump into a PostgreSQL script."""

    def __init__(self, src):
        super().__init__(src, "PostgreSQL")

    def translate_type(self, type: str, size: Optional[Union[int, Tuple[int, ...]]]):
        """Map a MySQL column type to its PostgreSQL equivalent.

        Types not listed here fall through and yield ``None``.
        """
        type = type.lower()

        if type == "varchar":
            return f"varchar({size})"
        if type == "int":
            return "int4"
        if type == "bigint" or type == "bigint unsigned":
            return "int8"
        if type == "datetime":
            return "timestamp"
        if type == "bit":
            return "bool"
        if type in ("tinyint", "smallint"):
            return "int2"
        if type == "text":
            return "text"
        if type in ("blob", "mediumblob"):
            return "bytea"
        if type == "decimal":
            # size may be None (bare "decimal"), a scalar ("decimal(10)") or a
            # tuple ("decimal(10,2)"); len(size) used to crash on the first two.
            dims = (size,) if isinstance(size, (int, str)) else tuple(size or ())
            return f"numeric({','.join(str(s) for s in dims)})" if dims else "numeric"

    def gen_create(self, ddl: Dict) -> str:
        """Generate DROP + CREATE TABLE for one table."""

        def _generate_column(col):
            name = col["name"].lower()
            # The logical-delete flag is always a NOT NULL small integer.
            if name == "deleted":
                return "deleted int2 NOT NULL DEFAULT 0"

            type = col["type"].lower()
            full_type = self.translate_type(type, col["size"])
            nullable = "NULL" if col["nullable"] else "NOT NULL"
            default = f"DEFAULT {col['default']}" if col["default"] is not None else ""
            return f"{name} {full_type} {nullable} {default}"

        table_name = ddl["table_name"].lower()
        columns = [f"{_generate_column(col).strip()}" for col in ddl["columns"]]
        filed_def_list = ",\n    ".join(columns)
        script = f"""-- ----------------------------
-- Table structure for {table_name}
-- ----------------------------
DROP TABLE IF EXISTS {table_name};
CREATE TABLE {table_name} (
    {filed_def_list}
);"""

        return script

    def gen_index(self, ddl: Dict) -> str:
        """Generate ';'-terminated CREATE INDEX statements."""
        return "\n".join(f"{script};" for script in self.index(ddl))

    def gen_comment(self, table_sql: str, table_name: str) -> str:
        """Generate COMMENT ON statements for the columns and the table."""
        script = ""
        for field, comment_string in self.filed_comments(table_sql):
            script += (
                f"COMMENT ON COLUMN {table_name}.{field} IS '{comment_string}';" + "\n"
            )

        table_comment = self.table_comment(table_sql)
        if table_comment:
            script += f"COMMENT ON TABLE {table_name} IS '{table_comment}';\n"

        return script

    def gen_pk(self, table_name) -> str:
        """Generate the primary key constraint on column ``id``."""
        return f"ALTER TABLE {table_name} ADD CONSTRAINT pk_{table_name} PRIMARY KEY (id);\n"

    def gen_insert(self, table_name: str) -> str:
        """Generate the INSERT block and a ``<table>_seq`` sequence.

        The sequence starts one past the largest leading id value found in
        the INSERT statements (1 when there are none).
        """
        inserts = list(Convertor.inserts(table_name, self.content))
        # Build the INSERT script.
        script = ""
        last_id = 0
        if inserts:
            inserts_lines = "\n".join(inserts)
            script += f"""\n\n-- ----------------------------
-- Records of {table_name.lower()}
-- ----------------------------
-- @formatter:off
BEGIN;
{inserts_lines}
COMMIT;
-- @formatter:on"""
            # Use the largest id, not just the last statement's: the dump is
            # not guaranteed to be ordered by id, and a lower start value
            # would make the sequence collide with existing rows.
            ids = [
                int(match.group(1))
                for match in (re.search(r"VALUES \((\d+),", s) for s in inserts)
                if match
            ]
            last_id = max(ids, default=0)

        # Generate the sequence.
        script += (
            "\n\n"
            + f"""DROP SEQUENCE IF EXISTS {table_name}_seq;
CREATE SEQUENCE {table_name}_seq
    START {last_id + 1};"""
        )

        return script

    def gen_dual(self) -> str:
        """Create an empty ``dual`` table."""
        return """DROP TABLE IF EXISTS dual;
CREATE TABLE dual
(
);"""
|
375 |
|
|
376 |
|
|
377 |
class OracleConvertor(Convertor):
    """Convert the MySQL dump into an Oracle script."""

    def __init__(self, src):
        super().__init__(src, "Oracle")

    def translate_type(self, type: str, size: Optional[Union[int, Tuple[int, ...]]]):
        """Map a MySQL column type to its Oracle equivalent.

        varchar lengths are capped at 4000. Types not listed here fall
        through and yield ``None``.
        """
        type = type.lower()

        if type == "varchar":
            return f"varchar2({size if size < 4000 else 4000})"
        if type == "int":
            return "number"
        if type == "bigint" or type == "bigint unsigned":
            return "number"
        if type == "datetime":
            return "date"
        if type == "bit":
            return "number(1,0)"
        if type in ("tinyint", "smallint"):
            return "smallint"
        if type == "text":
            return "clob"
        if type in ("blob", "mediumblob"):
            return "blob"
        if type == "decimal":
            # size may be None (bare "decimal"), a scalar ("decimal(10)") or a
            # tuple ("decimal(10,2)"); len(size) used to crash on the first two.
            dims = (size,) if isinstance(size, (int, str)) else tuple(size or ())
            return f"number({','.join(str(s) for s in dims)})" if dims else "number"

    def gen_create(self, ddl) -> str:
        """Generate the CREATE TABLE statement."""

        def generate_column(col):
            name = col["name"].lower()
            if name == "deleted":
                return "deleted number(1,0) DEFAULT 0 NOT NULL"

            type = col["type"].lower()
            full_type = self.translate_type(type, col["size"])
            nullable = "NULL" if col["nullable"] else "NOT NULL"
            default = f"DEFAULT {col['default']}" if col["default"] is not None else ""
            # "size" cannot be used as a bare column name in Oracle.
            field_name = '"size"' if name == "size" else name
            # Oracle expects DEFAULT before the NULL / NOT NULL clause.
            return f"{field_name} {full_type} {default} {nullable}"

        table_name = ddl["table_name"].lower()
        columns = [f"{generate_column(col).strip()}" for col in ddl["columns"]]
        field_def_list = ",\n    ".join(columns)
        script = f"""-- ----------------------------
-- Table structure for {table_name}
-- ----------------------------
CREATE TABLE {table_name} (
    {field_def_list}
);"""

        # Oracle: inserting '' fails a NOT NULL check, so such columns are
        # relaxed to NULL.
        script = script.replace("DEFAULT '' NOT NULL", "DEFAULT '' NULL")

        return script

    def gen_index(self, ddl: Dict) -> str:
        """Generate ';'-terminated CREATE INDEX statements."""
        return "\n".join(f"{script};" for script in self.index(ddl))

    def gen_comment(self, table_sql: str, table_name: str) -> str:
        """Generate COMMENT ON statements for the columns and the table."""
        script = ""
        for field, comment_string in self.filed_comments(table_sql):
            script += (
                f"COMMENT ON COLUMN {table_name}.{field} IS '{comment_string}';" + "\n"
            )

        table_comment = self.table_comment(table_sql)
        if table_comment:
            script += f"COMMENT ON TABLE {table_name} IS '{table_comment}';\n"

        return script

    def gen_pk(self, table_name: str) -> str:
        """Generate the primary key constraint on column ``id``."""
        return f"ALTER TABLE {table_name} ADD CONSTRAINT pk_{table_name} PRIMARY KEY (id);\n"

    def gen_insert(self, table_name: str) -> str:
        """Copy the INSERT statements and create a ``<table>_seq`` sequence.

        Datetime literals are wrapped in ``to_date``. The sequence starts one
        past the largest leading id value found (1 when there are none).
        """
        inserts = []
        for insert_script in Convertor.inserts(table_name, self.content):
            # Wrap 'YYYY-MM-DD HH:MM:SS' literals in a TO_DATE conversion.
            insert_script = re.sub(
                r"('\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}')",
                r"to_date(\g<1>, 'SYYYY-MM-DD HH24:MI:SS')",
                insert_script,
            )
            inserts.append(insert_script)

        # Build the INSERT script.
        script = ""
        last_id = 0
        if inserts:
            inserts_lines = "\n".join(inserts)
            script += f"""\n\n-- ----------------------------
-- Records of {table_name.lower()}
-- ----------------------------
-- @formatter:off
{inserts_lines}
COMMIT;
-- @formatter:on"""
            # Largest id, not the last statement's: the dump is not
            # guaranteed to be ordered by id.
            ids = [
                int(match.group(1))
                for match in (re.search(r"VALUES \((\d+),", s) for s in inserts)
                if match
            ]
            last_id = max(ids, default=0)

        # Generate the sequence.
        script += f"""

CREATE SEQUENCE {table_name}_seq
    START WITH {last_id + 1};"""

        return script
|
496 |
|
|
497 |
|
|
498 |
class SQLServerConvertor(Convertor):
    """Convert the MySQL dump into a Microsoft SQL Server (T-SQL) script.

    Statements are separated with ``GO`` batch terminators. Column ``id``
    becomes an IDENTITY primary key, so :meth:`gen_pk` emits nothing.
    """

    def __init__(self, src):
        super().__init__(src, "Microsoft SQL Server")

    def translate_type(self, type: str, size: Optional[Union[int, Tuple[int, ...]]]):
        """Map a MySQL column type to its SQL Server equivalent.

        nvarchar lengths are capped at 4000. Types not listed here fall
        through and yield ``None``.
        """
        type = type.lower()

        if type == "varchar":
            return f"nvarchar({size if size < 4000 else 4000})"
        if type == "int":
            return "int"
        if type == "bigint" or type == "bigint unsigned":
            return "bigint"
        if type == "datetime":
            return "datetime2"
        if type == "bit":
            return "varchar(1)"
        # NOTE(review): SQL Server tinyint is 0..255, narrower than MySQL
        # smallint; confirm no stored values fall outside that range.
        if type in ("tinyint", "smallint"):
            return "tinyint"
        if type == "text":
            return "nvarchar(max)"
        if type in ("blob", "mediumblob"):
            return "varbinary(max)"
        if type == "decimal":
            # size may be None (bare "decimal"), a scalar ("decimal(10)") or a
            # tuple ("decimal(10,2)"); len(size) used to crash on the first two.
            dims = (size,) if isinstance(size, (int, str)) else tuple(size or ())
            return f"numeric({','.join(str(s) for s in dims)})" if dims else "numeric"

    def gen_create(self, ddl: Dict) -> str:
        """Generate DROP + CREATE TABLE for one table."""

        def _generate_column(col):
            name = col["name"].lower()
            if name == "id":
                return "id bigint NOT NULL PRIMARY KEY IDENTITY"
            if name == "deleted":
                return "deleted bit DEFAULT 0 NOT NULL"

            type = col["type"].lower()
            full_type = self.translate_type(type, col["size"])
            nullable = "NULL" if col["nullable"] else "NOT NULL"
            default = f"DEFAULT {col['default']}" if col["default"] is not None else ""
            return f"{name} {full_type} {default} {nullable}"

        table_name = ddl["table_name"].lower()
        columns = [f"{_generate_column(col).strip()}" for col in ddl["columns"]]
        filed_def_list = ",\n    ".join(columns)
        script = f"""-- ----------------------------
-- Table structure for {table_name}
-- ----------------------------
DROP TABLE IF EXISTS {table_name};
CREATE TABLE {table_name} (
    {filed_def_list}
)
GO"""

        return script

    def gen_comment(self, table_sql: str, table_name: str) -> str:
        """Generate ``sp_addextendedproperty`` calls for column/table comments."""
        script = ""

        for field, comment_string in self.filed_comments(table_sql):
            script += f"""EXEC sp_addextendedproperty
    'MS_Description', N'{comment_string}',
    'SCHEMA', N'dbo',
    'TABLE', N'{table_name}',
    'COLUMN', N'{field}'
GO

"""

        table_comment = self.table_comment(table_sql)
        if table_comment:
            script += f"""EXEC sp_addextendedproperty
    'MS_Description', N'{table_comment}',
    'SCHEMA', N'dbo',
    'TABLE', N'{table_name}'
GO

"""
        return script

    def gen_pk(self, table_name: str) -> str:
        """No separate primary key: ``id`` is declared PRIMARY KEY inline."""
        return ""

    def gen_index(self, ddl: Dict) -> str:
        """Generate CREATE INDEX statements, each followed by ``GO``."""
        return "\n".join(f"{script}\nGO" for script in self.index(ddl))

    def gen_insert(self, table_name: str) -> str:
        """Generate the INSERT block wrapped in IDENTITY_INSERT ON/OFF."""
        # Collect the INSERT statements for `table_name`.
        inserts = []
        for insert_script in Convertor.inserts(table_name, self.content):
            # Prefix string literals with N (Unicode). This is a textual hack:
            # string contents containing ", '" are rewritten too.
            # The replacement must be "VALUES (N'" - a stray ")" here used to
            # corrupt rows whose first value is a string.
            insert_script = insert_script.replace(", '", ", N'").replace(
                "VALUES ('", "VALUES (N'"
            )
            # Replace the trailing semicolon with a GO batch separator.
            insert_script = re.sub(";$", r"\nGO", insert_script)
            inserts.append(insert_script)

        # Build the INSERT script.
        script = ""
        if inserts:
            inserts_lines = "\n".join(inserts)
            script += f"""\n\n-- ----------------------------
-- Records of {table_name.lower()}
-- ----------------------------
-- @formatter:off
BEGIN TRANSACTION
GO
SET IDENTITY_INSERT {table_name.lower()} ON
GO
{inserts_lines}
SET IDENTITY_INSERT {table_name.lower()} OFF
GO
COMMIT
GO
-- @formatter:on"""

        return script

    def gen_dual(self) -> str:
        """Create a one-column ``dual`` table with a description."""
        return """DROP TABLE IF EXISTS dual
GO

CREATE TABLE dual
(
    id int NULL
)
GO

EXEC sp_addextendedproperty
    'MS_Description', N'数据库连接的表',
    'SCHEMA', N'dbo',
    'TABLE', N'dual'
GO"""
|
648 |
|
|
649 |
|
|
650 |
class DM8Convertor(Convertor):
    """Convert the MySQL dump into a DM8 (Dameng) script.

    Column ``id`` becomes an IDENTITY primary key, so :meth:`gen_pk` emits
    nothing.
    """

    def __init__(self, src):
        super().__init__(src, "DM8")

    def translate_type(self, type: str, size: Optional[Union[int, Tuple[int, ...]]]):
        """Map a MySQL column type to its DM8 equivalent.

        Types not listed here fall through and yield ``None``.
        """
        type = type.lower()

        if type == "varchar":
            return f"varchar({size})"
        if type == "int":
            return "int"
        if type == "bigint" or type == "bigint unsigned":
            return "bigint"
        if type == "datetime":
            return "datetime"
        if type == "bit":
            return "bit"
        if type in ("tinyint", "smallint"):
            return "smallint"
        if type == "text":
            return "text"
        if type == "blob":
            return "blob"
        if type == "mediumblob":
            return "varchar(10240)"
        if type == "decimal":
            # size may be None (bare "decimal"), a scalar ("decimal(10)") or a
            # tuple ("decimal(10,2)"); len(size) used to crash on the first two.
            dims = (size,) if isinstance(size, (int, str)) else tuple(size or ())
            return f"decimal({','.join(str(s) for s in dims)})" if dims else "decimal"

    def gen_create(self, ddl) -> str:
        """Generate the CREATE TABLE statement."""

        def generate_column(col):
            name = col["name"].lower()
            if name == "id":
                return "id bigint NOT NULL PRIMARY KEY IDENTITY"

            type = col["type"].lower()
            full_type = self.translate_type(type, col["size"])
            nullable = "NULL" if col["nullable"] else "NOT NULL"
            default = f"DEFAULT {col['default']}" if col["default"] is not None else ""
            return f"{name} {full_type} {default} {nullable}"

        table_name = ddl["table_name"].lower()
        columns = [f"{generate_column(col).strip()}" for col in ddl["columns"]]
        field_def_list = ",\n    ".join(columns)
        script = f"""-- ----------------------------
-- Table structure for {table_name}
-- ----------------------------
CREATE TABLE {table_name} (
    {field_def_list}
);"""

        # Same workaround as the Oracle convertor (the original note says an
        # INSERT of '' fails the NOT NULL check): relax such columns to NULL.
        script = script.replace("DEFAULT '' NOT NULL", "DEFAULT '' NULL")

        return script

    def gen_index(self, ddl: Dict) -> str:
        """Generate ';'-terminated CREATE INDEX statements."""
        return "\n".join(f"{script};" for script in self.index(ddl))

    def gen_comment(self, table_sql: str, table_name: str) -> str:
        """Generate COMMENT ON statements for the columns and the table."""
        script = ""
        for field, comment_string in self.filed_comments(table_sql):
            script += (
                f"COMMENT ON COLUMN {table_name}.{field} IS '{comment_string}';" + "\n"
            )

        table_comment = self.table_comment(table_sql)
        if table_comment:
            script += f"COMMENT ON TABLE {table_name} IS '{table_comment}';\n"

        return script

    def gen_pk(self, table_name: str) -> str:
        """No separate primary key: ``id`` is declared PRIMARY KEY inline."""
        return ""

    def gen_insert(self, table_name: str) -> str:
        """Copy the INSERT statements, wrapped in IDENTITY_INSERT ON/OFF."""
        inserts = list(Convertor.inserts(table_name, self.content))

        # Build the INSERT script.
        script = ""
        if inserts:
            inserts_lines = "\n".join(inserts)
            script += f"""\n\n-- ----------------------------
-- Records of {table_name.lower()}
-- ----------------------------
-- @formatter:off
SET IDENTITY_INSERT {table_name.lower()} ON;
{inserts_lines}
COMMIT;
SET IDENTITY_INSERT {table_name.lower()} OFF;
-- @formatter:on"""

        return script
|
752 |
|
|
753 |
|
|
754 |
def main():
    """Parse the command line and print the converted script to stdout.

    Positional ``type`` selects the target database. The optional
    ``--sql-file`` names the MySQL source script; its default is the same
    path used before (resolved against the current working directory).
    """
    parser = argparse.ArgumentParser(description="平台系统数据库转换工具")
    parser.add_argument(
        "type",
        type=str,
        help="目标数据库类型",
        choices=["postgres", "oracle", "sqlserver", "dm8"],
    )
    parser.add_argument(
        "--sql-file",
        dest="sql_file",
        default="../mysql/ruoyi-vue-pro.sql",
        help="源 MySQL 脚本路径 (默认: ../mysql/ruoyi-vue-pro.sql, 相对于当前工作目录)",
    )
    args = parser.parse_args()

    sql_file = pathlib.Path(args.sql_file).resolve().as_posix()
    # Dispatch table from CLI choice to convertor class.
    convertor_classes = {
        "postgres": PostgreSQLConvertor,
        "oracle": OracleConvertor,
        "sqlserver": SQLServerConvertor,
        "dm8": DM8Convertor,
    }
    convertor_cls = convertor_classes.get(args.type)
    if convertor_cls is None:
        # Unreachable via argparse `choices`; kept as a safety net.
        raise NotImplementedError(f"不支持目标数据库类型: {args.type}")

    convertor_cls(sql_file).print()


if __name__ == "__main__":
    main()