import datetime
import json
from urllib import parse

import numpy as np
import pandas as pd
from pydoris.doris_client import DorisClient
from sqlalchemy import create_engine, text
from sqlalchemy.orm import sessionmaker
from sshtunnel import SSHTunnelForwarder


# Data-warehouse helper library.
class DwService():
    """Copy rows from MySQL tables into Doris tables.

    Rows are read through SQLAlchemy, converted to JSON-friendly dicts and
    PUT as one JSON array to the URL built by the pydoris client
    (presumably the Doris stream-load endpoint -- confirm against pydoris).

    Two sync modes are offered, each with an optional SSH-tunnel variant:

    * full sync: ``tables`` is an iterable of MySQL table names.
    * incremental sync: ``tables`` is an iterable of dicts shaped like
      ``{table_name: {"time_field": column_name}}``; only rows with
      ``start_time <= time_field < end_time`` are copied.

    NOTE: table and column names are interpolated into SQL text, so they
    must come from trusted configuration, never from end-user input.
    """

    # String form of MySQL's "zero" datetime.
    _ZERO_DATETIME = '0000-00-00 00:00:00'
    # Date part of the Unix epoch; matched as a substring of the value.
    _EPOCH_DATE = '1970-01-01'
    # Type prefixes whose values are sent to Doris as strings. 'date' also
    # covers 'datetime', and prefix matching covers precision suffixes such
    # as 'datetime(6)' or 'timestamp(3)'.
    _TEMPORAL_PREFIXES = ('date', 'timestamp')

    def __init__(self, tables, doris_host, doris_port, doris_user, doris_password, doris_db, doris_table_format,
                 mysql_host, mysql_port, mysql_user, mysql_password, mysql_db,
                 start_time="", end_time="", ssh_host="", ssh_name="", ssh_password="", settings=None):
        """Store connection parameters and merge ``settings`` over the defaults.

        Args:
            tables: Tables to sync; see the class docstring for the shape
                expected by each sync mode.
            doris_host, doris_port, doris_user, doris_password: Passed to
                ``DorisClient``.
            doris_db: Target Doris database.
            doris_table_format: ``str.format`` template that receives the
                MySQL table name and yields the Doris table name.
            mysql_host, mysql_port, mysql_user, mysql_password, mysql_db:
                MySQL connection parameters; ``mysql_port`` is cast to int.
            start_time, end_time: Half-open window ``[start_time, end_time)``
                used by the incremental sync.
            ssh_host, ssh_name, ssh_password: SSH credentials used by the
                ``*_with_ssh`` variants.
            settings: Optional overrides for the default settings
                ``skip_endtime_0000`` and ``skip_endtime_1970`` (both the
                string ``"true"`` by default).
        """
        self.tables = tables
        self.doris_host = doris_host
        self.doris_port = doris_port
        self.doris_user = doris_user
        self.doris_password = doris_password
        self.doris_db = doris_db
        self.doris_table_format = doris_table_format
        self.mysql_host = mysql_host
        self.mysql_port = int(mysql_port)
        self.mysql_user = mysql_user
        self.mysql_password = mysql_password
        self.mysql_db = mysql_db
        self.start_time = start_time
        self.end_time = end_time
        self.ssh_host = ssh_host
        self.ssh_name = ssh_name
        self.ssh_password = ssh_password
        merged_settings = {"skip_endtime_0000": "true", "skip_endtime_1970": "true"}
        if settings:
            merged_settings.update(settings)
        self.settings = merged_settings

    def _get_db(self):
        """Create the Doris client, the MySQL engine and the session factory."""
        self.doris_client = DorisClient(self.doris_host, self.doris_port, self.doris_user, self.doris_password)
        self.doris_client.options\
            .set_json_format()\
            .set_auto_uuid_label()\
            .set_option('strip_outer_array', 'true')
        # The password is URL-quoted so special characters survive URL parsing.
        self.db_engine = create_engine(
            "mysql+pymysql://{}:{}@{}:{}/{}?charset=utf8".format(
                self.mysql_user, parse.quote_plus(self.mysql_password),
                self.mysql_host, int(self.mysql_port), self.mysql_db),
            pool_recycle=3600, pool_size=5)
        self.Session = sessionmaker(self.db_engine)

    def _convert_row(self, row, columns, types):
        """Convert one result row into a ``{column: value}`` dict.

        Returns ``None`` when the row must be skipped: any value whose string
        form equals the zero datetime (setting ``skip_endtime_0000``) or
        contains ``1970-01-01`` (setting ``skip_endtime_1970``). These
        placeholder timestamps appear in some source tables and would
        pollute the Doris index column.
        """
        skip_zero = self.settings.get("skip_endtime_0000") == "true"
        skip_epoch = self.settings.get("skip_endtime_1970") == "true"
        record = {}
        for column, col_type, item in zip(columns, types, row):
            as_text = str(item)
            if skip_zero and as_text == self._ZERO_DATETIME:
                return None
            if skip_epoch and self._EPOCH_DATE in as_text:
                return None
            if item is None:
                # Keep SQL NULL as JSON null instead of the string "None".
                record[column] = None
            elif col_type.startswith(self._TEMPORAL_PREFIXES) or 'decimal' in col_type:
                # datetime/date/Decimal objects are not JSON serializable.
                record[column] = as_text
            elif col_type.startswith('tinyint'):
                record[column] = 0 if item == 'false' else 1 if item == 'true' else item
            else:
                record[column] = item
        return record

    def _get_and_write(self, session, sql, table, params=None):
        """Read rows from MySQL with ``sql`` and load them into Doris.

        Args:
            session: Open SQLAlchemy session bound to the MySQL engine.
            sql: SELECT statement whose columns follow the order reported by
                ``desc <table>`` (i.e. ``select *``).
            table: MySQL table name; also formatted into the Doris table name.
            params: Optional bound parameters for ``sql``.

        Raises:
            RuntimeError: When Doris does not report a successful load; the
                message is the raw response text.
        """
        doris_table = self.doris_table_format.format(table)
        print(f'{datetime.datetime.now()}-{doris_table} 数据开始同步...............')
        # Refresh the auto-generated label before each load.
        self.doris_client.options.set_auto_uuid_label()
        df_desc = pd.read_sql(f'desc {table}', self.db_engine)
        rows = session.execute(text(sql), params or {}).fetchall()
        columns = df_desc['Field'].tolist()
        types = df_desc['Type'].tolist()

        converted = (self._convert_row(row, columns, types) for row in rows)
        records = [record for record in converted if record is not None]
        if not records:
            print(f"{datetime.datetime.now()}-本次需同步的数据条数0")
            return

        print(f'{datetime.datetime.now()}-{doris_table} 数据开始导入doris...............')
        # Keep the auth header when the request is redirected to another URL.
        self.doris_client._session.should_strip_auth = lambda old_url, new_url: False
        resp = self.doris_client._session.request(
            'PUT',
            url=self.doris_client._build_url(self.doris_db, doris_table),
            data=json.dumps(records).encode('utf-8'),
            headers=self.doris_client.options.get_options(),
            auth=self.doris_client._auth
        )
        # Parse defensively: an error response may not be JSON at all, and the
        # caller should then see the response text, not a decode error.
        try:
            load_ok = json.loads(resp.text).get('Status') == 'Success'
        except (ValueError, AttributeError):
            load_ok = False
        if resp.status_code == 200 and resp.reason == 'OK' and load_ok:
            print(f'{datetime.datetime.now()}-{doris_table} 数据导入成功')
            print(resp.text)
        else:
            print(f'{datetime.datetime.now()}-{doris_table} 数据导入失败')
            raise RuntimeError(resp.text)

    def full_table_sync(self):
        """Full sync: copy every row of each table named in ``self.tables``."""
        self._get_db()
        try:
            with self.Session() as session:
                for table in self.tables:
                    self._get_and_write(session, 'select * from {}'.format(table), table)
        finally:
            # Release pooled connections (matters when an SSH tunnel is
            # about to be closed underneath them).
            self.db_engine.dispose()

    def inc_table_sync(self):
        """Incremental sync: copy rows whose time field lies in the window.

        Raises:
            ValueError: When ``start_time`` or ``end_time`` is empty.
        """
        if self.start_time == "" or self.end_time == "":
            raise ValueError("输入参数错误")
        # Identifiers cannot be bound, but the window bounds can.
        sql_format = "select * from {table} where {field} >= :start_time and {field} < :end_time"
        window = {"start_time": self.start_time, "end_time": self.end_time}
        self._get_db()
        try:
            with self.Session() as session:
                for table_dict in self.tables:
                    for table, value in table_dict.items():
                        sql = sql_format.format(table=table, field=value.get('time_field'))
                        self._get_and_write(session, sql, table, window)
        finally:
            self.db_engine.dispose()

    def _sync_through_tunnel(self, sync):
        """Run ``sync`` with MySQL traffic forwarded through an SSH tunnel.

        The MySQL host/port are pointed at the local end of the tunnel only
        for the duration of the call and restored afterwards, so the service
        can be reused.

        Raises:
            ValueError: When any SSH connection parameter is empty.
        """
        if self.ssh_host == "" or self.ssh_name == "" or self.ssh_password == "":
            raise ValueError("输入参数错误")
        remote_endpoint = (self.mysql_host, self.mysql_port)
        # NOTE(review): 22 is passed as the second positional argument, as in
        # the original code; confirm against the sshtunnel signature that it
        # is interpreted as the SSH port.
        with SSHTunnelForwarder(self.ssh_host, 22,
                                ssh_username=self.ssh_name,
                                ssh_password=self.ssh_password,
                                remote_bind_address=remote_endpoint) as tunnel:
            self.mysql_host = '127.0.0.1'
            self.mysql_port = tunnel.local_bind_port
            try:
                sync()
            finally:
                self.mysql_host, self.mysql_port = remote_endpoint

    def full_table_sync_with_ssh(self):
        """Full sync with the MySQL connection routed through an SSH tunnel."""
        self._sync_through_tunnel(self.full_table_sync)

    def inc_table_sync_with_ssh(self):
        """Incremental sync with the MySQL connection routed through an SSH tunnel."""
        self._sync_through_tunnel(self.inc_table_sync)