123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137 |
- from sshtunnel import SSHTunnelForwarder
- from sqlalchemy.orm import sessionmaker
- from sqlalchemy import create_engine
- import pandas as pd
- from urllib import parse
- from pydoris.doris_client import DorisClient
- import json
- import datetime
- import numpy as np
# Data-warehouse helper library: sync MySQL tables into Doris.
class DwService():
    """Copy rows from MySQL tables into Doris tables over HTTP.

    Rows are read from MySQL through SQLAlchemy, converted to JSON-friendly
    dicts and sent to Doris with a single ``PUT`` request per table.

    Args:
        tables: Tables to sync. ``full_table_sync`` iterates it as a list of
            table names; ``inc_table_sync`` iterates it as a list of dicts
            shaped like ``{table_name: {"time_field": column_name}}``.
        doris_host, doris_port, doris_user, doris_password: Passed to
            ``DorisClient``.
        doris_db: Target Doris database.
        doris_table_format: ``str.format`` template that turns a MySQL table
            name into the Doris table name (e.g. ``"ods_{}"``).
        mysql_host, mysql_port, mysql_user, mysql_password, mysql_db: MySQL
            connection parameters; ``mysql_port`` is coerced to ``int``.
        start_time, end_time: Half-open window ``[start_time, end_time)``
            used by the incremental sync.
        ssh_host, ssh_name, ssh_password: SSH gateway credentials used by the
            ``*_with_ssh`` variants.
        settings: Optional overrides merged over the defaults
            ``{"skip_endtime_0000": "true", "skip_endtime_1970": "true"}``.
    """

    # MySQL reports types such as ``datetime(6)`` or ``tinyint(1)``, so column
    # types are matched by prefix. ``"date"`` also covers ``datetime``.
    _TEMPORAL_PREFIXES = ('date', 'timestamp')

    def __init__(self, tables, doris_host, doris_port, doris_user, doris_password, doris_db,
                 doris_table_format, mysql_host, mysql_port, mysql_user, mysql_password, mysql_db,
                 start_time="", end_time="", ssh_host="", ssh_name="", ssh_password="",
                 settings=None):
        self.tables = tables
        self.doris_host = doris_host
        self.doris_port = doris_port
        self.doris_user = doris_user
        self.doris_password = doris_password
        self.doris_db = doris_db
        self.doris_table_format = doris_table_format
        self.mysql_host = mysql_host
        self.mysql_port = int(mysql_port)
        self.mysql_user = mysql_user
        self.mysql_password = mysql_password
        self.mysql_db = mysql_db
        self.start_time = start_time
        self.end_time = end_time
        self.ssh_host = ssh_host
        self.ssh_name = ssh_name
        self.ssh_password = ssh_password
        # ``settings`` defaults to None rather than a shared mutable ``{}``;
        # caller-supplied keys override the defaults.
        merged_settings = {"skip_endtime_0000": "true", "skip_endtime_1970": "true"}
        merged_settings.update(settings or {})
        self.settings = merged_settings

    def _get_db(self):
        """Create the Doris client, the MySQL engine and the session factory."""
        self.doris_client = DorisClient(self.doris_host, self.doris_port, self.doris_user, self.doris_password)
        self.doris_client.options\
            .set_json_format()\
            .set_auto_uuid_label()\
            .set_option('strip_outer_array', 'true')
        url = "mysql+pymysql://{}:{}@{}:{}/{}?charset=utf8".format(
            self.mysql_user,
            parse.quote_plus(self.mysql_password),  # password may contain URL-reserved characters
            self.mysql_host,
            int(self.mysql_port),
            self.mysql_db,
        )
        self.db_engine = create_engine(url, pool_recycle=3600, pool_size=5)
        self.Session = sessionmaker(self.db_engine)

    def _dispose_db(self):
        """Close the pooled MySQL connections opened by ``_get_db``."""
        engine = getattr(self, 'db_engine', None)
        if engine is not None:
            engine.dispose()

    def _row_to_record(self, row, columns, types):
        """Convert one MySQL row into a JSON-serialisable dict.

        Args:
            row: Iterable of column values, in ``columns`` order.
            columns: Column names (the ``Field`` column of ``desc <table>``).
            types: Column types (the ``Type`` column of ``desc <table>``).

        Returns:
            The converted dict, or ``None`` when the row must be skipped
            because a value matches an enabled ``skip_endtime_*`` rule.
        """
        skip_zero_dates = self.settings.get("skip_endtime_0000") == "true"
        skip_epoch_dates = self.settings.get("skip_endtime_1970") == "true"
        record = {}
        for column, col_type, item in zip(columns, types, row):
            text = str(item)
            # Some "done" tables hold endtime=0000-00-00 / 1970-01-01 values
            # that break the index field. The check applies to every column of
            # the row, and a match drops the whole row.
            if skip_zero_dates and text == '0000-00-00 00:00:00':
                return None
            if skip_epoch_dates and '1970-01-01' in text:
                return None
            if item is None:
                # Keep SQL NULL as JSON null instead of the string "None".
                record[column] = None
            elif col_type.startswith(self._TEMPORAL_PREFIXES) or 'decimal' in col_type:
                # date/datetime/Decimal objects are not JSON-serialisable.
                record[column] = text
            elif col_type.startswith('tinyint'):
                record[column] = 0 if item == 'false' else 1 if item == 'true' else item
            else:
                record[column] = item
        return record

    def _get_and_write(self, session, sql, table, ):
        """Run ``sql`` against MySQL and load the converted rows into Doris.

        Args:
            session: Open SQLAlchemy session used to execute ``sql``.
            sql: Select statement whose columns follow ``desc <table>`` order.
            table: MySQL table name; also fed to ``doris_table_format``.

        Raises:
            Exception: With the response body when Doris does not answer
                HTTP 200/OK with a JSON ``Status`` of ``Success``.
        """
        doris_table = self.doris_table_format.format(table)
        print(f'{datetime.datetime.now()}-{doris_table} 数据开始同步...............')
        # A fresh label per load; the label is set once more here because this
        # method runs once per table on the same client.
        self.doris_client.options.set_auto_uuid_label()
        # NOTE(review): ``table`` is interpolated into SQL here and by the
        # callers; it must come from trusted configuration.
        df_desc = pd.read_sql(f'desc {table}', self.db_engine)
        rows = session.execute(sql).fetchall()
        columns = df_desc['Field'].tolist()
        types = df_desc['Type'].tolist()

        records = []
        for row in rows:
            record = self._row_to_record(row, columns, types)
            if record is not None:
                records.append(record)

        if not records:
            print(f"{datetime.datetime.now()}-本次需同步的数据条数0")
            return

        print(f'{datetime.datetime.now()}-{doris_table} 数据开始导入doris...............')
        # Never strip the auth header when the request is redirected.
        self.doris_client._session.should_strip_auth = lambda old_url, new_url: False
        resp = self.doris_client._session.request(
            'PUT', url=self.doris_client._build_url(self.doris_db, doris_table),
            data=json.dumps(records).encode('utf-8'),
            headers=self.doris_client.options.get_options(),
            auth=self.doris_client._auth
        )
        # Check the HTTP status before parsing so that a non-JSON error page
        # is reported as a failed import, not as a JSON/Key error.
        load_ok = False
        if resp.status_code == 200 and resp.reason == 'OK':
            try:
                load_ok = json.loads(resp.text).get('Status') == 'Success'
            except (ValueError, AttributeError):
                load_ok = False
        if load_ok:
            print(f'{datetime.datetime.now()}-{doris_table} 数据导入成功')
            print(resp.text)
        else:
            print(f'{datetime.datetime.now()}-{doris_table} 数据导入失败')
            raise Exception(resp.text)

    def full_table_sync(self):
        """Full sync: copy every row of each table named in ``self.tables``."""
        sql_format = 'select * from {}'
        self._get_db()
        try:
            with self.Session() as session:
                for table in self.tables:
                    self._get_and_write(session, sql_format.format(table), table)
        finally:
            self._dispose_db()

    def inc_table_sync(self):
        """Incremental sync of rows with ``start_time <= time_field < end_time``.

        Raises:
            Exception: When ``start_time`` or ``end_time`` is empty.
        """
        if self.start_time == "" or self.end_time == "":
            raise Exception("输入参数错误")
        sql_format = "select * from {} where {} >= '{}' and {} < '{}'"
        self._get_db()
        try:
            with self.Session() as session:
                for table_dict in self.tables:
                    for table, value in table_dict.items():
                        time_field = value.get('time_field')
                        sql = sql_format.format(table, time_field, self.start_time, time_field, self.end_time)
                        self._get_and_write(session, sql, table)
        finally:
            self._dispose_db()

    def _sync_through_ssh(self, sync):
        """Run ``sync`` while MySQL is reached through an SSH tunnel.

        The MySQL host/port are pointed at the local end of the tunnel only
        for the duration of the call and restored afterwards, so the instance
        can be reused.

        Raises:
            Exception: When any of the SSH parameters is empty.
        """
        if self.ssh_host == "" or self.ssh_name == "" or self.ssh_password == "":
            raise Exception("输入参数错误")
        remote_host, remote_port = self.mysql_host, self.mysql_port
        # NOTE(review): the original call passed 22 as a second positional
        # argument, which sshtunnel appears to bind to ``ssh_config_file``
        # rather than the port. The documented (host, port) tuple is used
        # instead, and ``ssh_config_file=None`` keeps config-file lookup
        # disabled as before. Confirm against the installed sshtunnel version.
        with SSHTunnelForwarder((self.ssh_host, 22), ssh_config_file=None,
                                ssh_username=self.ssh_name, ssh_password=self.ssh_password,
                                remote_bind_address=(remote_host, remote_port)) as tunnel:
            self.mysql_host = '127.0.0.1'
            self.mysql_port = tunnel.local_bind_port
            try:
                sync()
            finally:
                self.mysql_host, self.mysql_port = remote_host, remote_port

    def full_table_sync_with_ssh(self):
        """Full sync with the MySQL connection routed through the SSH gateway."""
        self._sync_through_ssh(self.full_table_sync)

    def inc_table_sync_with_ssh(self):
        """Incremental sync with the MySQL connection routed through the SSH gateway."""
        self._sync_through_ssh(self.inc_table_sync)
|