# main_train.py

import json
import traceback
import os
import datetime
import random
import pickle

from apscheduler.schedulers.blocking import BlockingScheduler
import pandas as pd
from sqlalchemy.orm import sessionmaker
from ZlwlAlgosCommon.utils.ProUtils import *
from ZlwlAlgosCommon.service.iotp.IotpAlgoService import IotpAlgoService
from ZlwlAlgosCommon.service.iotp.Beans import DataField
from train2 import *
# Read parameters
def download(sn, start_time, end_time, df_snpk_list, df_algo_pack_param):
    snindex = list(df_snpk_list['sn']).index(sn)
    dataset = pd.DataFrame()
    # NOTE: df_tags_dataset, iotp_datafactory_service and columns must exist as
    # module-level globals when this is called; this file never sets them globally.
    for idx in df_tags_dataset.index:
        d = df_tags_dataset.loc[idx]
        factory_id = d['factory_id']
        user = d['create_by']
        sn_list = [d['sn']]
        start_time = str(d['start_time'])
        end_time = str(d['end_time'])
        df_data = iotp_datafactory_service.get_data(sn_list=sn_list, columns=columns,
                                                    start_time=start_time, end_time=end_time,
                                                    factory=factory_id, user=user)
        dataset = pd.concat([dataset, df_data])
    df_data['sn'] = sn
    df_data['imei'] = df_snpk_list['imei'][snindex]
    df_data2 = pd.DataFrame()
    if len(df_data) > 0:
        if df_data['imei'].isnull().any():
            df_data['imei'] = sn
        pack_code = df_snpk_list['pack_model'][snindex]
        df_pack_param = df_algo_pack_param[df_algo_pack_param['pack_code'] == pack_code]
        if len(df_pack_param) > 0:
            celpack_param = json.loads(df_pack_param.iloc[0]['param'])
            cellnum = celpack_param['CellVoltTotalCount']
            tempnum = celpack_param['CellTempTotalCount']
            capacity = celpack_param['capacity']
            df_data, df_table, cellvolt_name, celltemp_name = DataClean.datacleaning(df_data, cellnum, tempnum)
            df_data2 = features_total(df_data, capacity)
    return df_data2
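
# Hedged sketch: the pack-parameter JSON consumed above is only known from the three
# keys this file reads. _EXAMPLE_PACK_PARAM and _check_pack_param are hypothetical
# illustrations (field names are the ones download() reads, the values are made up):
_EXAMPLE_PACK_PARAM = '{"CellVoltTotalCount": 96, "CellTempTotalCount": 32, "capacity": 150}'

def _check_pack_param(param_json):
    # Decode a pack-param JSON string and verify the keys download() depends on.
    p = json.loads(param_json)
    missing = {'CellVoltTotalCount', 'CellTempTotalCount', 'capacity'} - set(p)
    if missing:
        raise KeyError(f'pack param missing keys: {missing}')
    return p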
# Download all fault data
def makedf(df, df_snpk_list, df_algo_pack_param):
    dataset = pd.DataFrame()
    df.reset_index(drop=True, inplace=True)
    split = 0
    for k in range(len(df)):
        try:
            sn = df.loc[k, 'sn']
            start_time = str(df.loc[k, 'start_time'])
            end_time = str(df.loc[k, 'end_time'])
            if end_time == 'NaT':
                end_time = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')  # type: str
            df_data2 = download(sn, start_time, end_time, df_snpk_list, df_algo_pack_param)
            split = split + 1
            df_data2['split'] = split
            dataset = pd.concat([dataset, df_data2])
        except Exception as e:
            logger.error('fault ' + sn + ' ' + start_time + ' ' + str(e))
            logger.error(traceback.format_exc())
    return dataset
# Randomly download normal data
def makedf_nor(df_snpk_list, df_algo_pack_param):
    SNnums = list(df_snpk_list['sn'])
    dataset = pd.DataFrame()
    split = 0
    for sn in list(set(SNnums))[:100]:
        try:
            snindex = list(df_snpk_list['sn']).index(sn)
            now_time = datetime.datetime.now()
            # pick a random 12-hour window from the past year
            start_time = now_time - datetime.timedelta(hours=random.randint(1, 365 * 24))
            end_time = str(start_time + datetime.timedelta(hours=12))[:19]
            start_time = start_time.strftime('%Y-%m-%d %H:%M:%S')
            df_data2 = download(sn, start_time, end_time, df_snpk_list, df_algo_pack_param)
            split = split + 1
            df_data2['split'] = split
            dataset = pd.concat([dataset, df_data2])
        except Exception as e:
            logger.error('normal ' + sn + ' ' + start_time + ' ' + str(e))
            logger.error(traceback.format_exc())
    return dataset
cur_env = 'dev'  # runtime environment
app_path = "/home/zhuxi/project/zlwl-algos/"  # absolute path of the app
app_name = ""
log_base_path = f"{os.path.dirname(os.path.abspath(__file__))}/log"  # log directory
sysUtils = SysUtils(cur_env, app_path)
logger = sysUtils.get_logger(app_name, log_base_path)
# mysql
mysql_algo_params = sysUtils.get_cf_param('mysql-algo')
mysqlUtils = MysqlUtils()
mysql_algo_engine, mysql_algo_Session = mysqlUtils.get_mysql_engine(mysql_algo_params)
db_engine = mysql_algo_engine.connect()
# redis
redis_params = sysUtils.get_cf_param('redis')
redisUtils = RedisUtils()
rc = redisUtils.get_redis_conncect(redis_params)
# hbase
hbase_params = sysUtils.get_cf_param('hbase')
iotp_service = IotpAlgoService(hbase_params=hbase_params)
Session = sessionmaker(bind=db_engine)
df_algo_adjustable_param, df_algo_list, df_algo_pack_param, df_snpk_list, df_snpk_list_scrap = update_param(db_engine, rc)
df_snpk_list = pd.concat([df_snpk_list, df_snpk_list_scrap])
df_snpk_list.reset_index(drop=True, inplace=True)
t_tag_child, r_battery_tag, t_algo_alarm_data_tag = update_lable(db_engine, rc)
list_fault = list(t_tag_child[t_tag_child['tag_type'] == 3]['name'])
sel_columns = [DataField.error_level, DataField.error_code, DataField.pack_crnt, DataField.pack_volt,
               DataField.bms_sta, DataField.cell_voltage_count, DataField.cell_temp_count, DataField.cell_voltage, DataField.cell_temp,
               DataField.pack_soc, DataField.other_temp_value, DataField.cell_balance,
               DataField.pack_soh, DataField.charge_sta]
# Leftover debug call, kept commented out: sn_list, columns, start_time and end_time
# are not defined at module level, so running it here would raise NameError.
# df_data = iotp_service.get_data(sn_list=sn_list, columns=columns, start_time=start_time, end_time=end_time)
# iotp_service.data_clean(df_data)
def diag_cal():
    global logger
    for fault in list_fault:
        list_models = os.listdir('Resources/AI_Fault_Class/V_auto/models')
        if 'model_' + fault + '.h5' not in list_models:
            parent_id = int(t_tag_child[t_tag_child['name'] == fault]['parent_id'].values[0])
            child_id = str(t_tag_child[t_tag_child['name'] == fault]['id'].values[0])
            object_id = list(r_battery_tag[(r_battery_tag['tag_id'] == parent_id) & (r_battery_tag['child_tag_list'] == child_id)]['object_id'])
            df_diag_ram = t_algo_alarm_data_tag[t_algo_alarm_data_tag['id'].isin(object_id)]
            # Only train when there are enough samples
            if len(df_diag_ram) > 50:
                # Collect samples: download all fault data
                logger.info("Downloading samples")
                datatest = makedf(df_diag_ram, df_snpk_list, df_algo_pack_param)  # deltatime is in seconds
                datatest.fillna(datatest.median(), inplace=True)  # fill NaNs with column medians
                datatest.to_csv('datatest' + fault + '.csv')
                # To resume from a previous run instead:
                # datatest = pd.read_csv('datatest' + fault + '.csv').drop(['Unnamed: 0'], axis=1)
                # Randomly download normal data
                dataset_nor = makedf_nor(df_snpk_list, df_algo_pack_param)
                dataset_nor.to_csv('datatestnor.csv')
                # dataset_nor = pd.read_csv('datatestnor.csv').drop(['Unnamed: 0'], axis=1)
                median = dataset_nor.median()
                median2 = str(median[['PackVolt', 'BMSSta', 'PackSoc', 'temp_max', 'temp_min', 'temp_mean', 'temp_diff', 'temp2_max', 'temp2_min', 'temp2_mean', 'temp2_diff']].to_dict())
                dataset_nor.fillna(median, inplace=True)  # fill NaNs with column medians
                logger.info("Training model")
                # Automatic training
                model, scaler, loss_th_max, loss_th_sum, time_steps, key_col = train(datatest, dataset_nor)
                logger.info("Saving parameters and model")
                # Save parameters: update the algo_list table
                list_fault_hist = list(df_algo_list[(df_algo_list['algo_id'] > 100) & (df_algo_list['algo_id'] < 200)]['algo_name'])
                list_pack_code = list(df_algo_adjustable_param['pack_code'].drop_duplicates())
                if fault + '_AI' not in list_fault_hist:
                    id = max(list(set(df_algo_list['id']))) + 1
                    algo_id = max(list(df_algo_list[(df_algo_list['algo_id'] > 100) & (df_algo_list['algo_id'] < 200)]['algo_id'])) + 1
                    create_time = str(datetime.datetime.now())
                    fault_code = 'C' + str(int(max(list(set(df_algo_list[(df_algo_list['fault_code'] > 'C250') & (df_algo_list['fault_code'] < 'C300')]['fault_code'])))[1:]) + 1)
                    input_param2 = pd.DataFrame({'id': [id], 'create_time': [create_time], 'create_by': ['zhuxi'], 'algo_id': [algo_id],
                                                 'algo_name': [fault + '_AI'], 'is_activate': [1], 'fault_level': [2], 'fault_code': [fault_code],
                                                 'fault_influence': ['potential safety risk'], 'model_type': [0], 'configurable_flag': [str(1100111)],
                                                 'is_delete': [0], 'model_alarm_type': [1]})
                    input_param2.to_sql("algo_list", con=db_engine, if_exists="append", index=False)
                    id2 = list(df_algo_adjustable_param[df_algo_adjustable_param['algo_id'] == algo_id]['id'])
                else:
                    algo_id = df_algo_list[df_algo_list['algo_name'] == fault + '_AI']['algo_id'].values[0]
                    id2 = list(range(max(list(set(df_algo_adjustable_param['id']))) + 1, max(list(set(df_algo_adjustable_param['id']))) + 1 + len(list_pack_code)))
                # Update the algo_adjustable_param table
                param_ai = {"time_steps": str(time_steps), "median": median2}
                param = {"key_feature": key_col, "loss_max": str(loss_th_max), "loss_sum": str(loss_th_sum)}
                input_param = pd.DataFrame({'id': id2, 'algo_id': algo_id, 'pack_code': list_pack_code, 'param': str(param), 'param_ai': str(param_ai)})
                session = Session()
                session.execute(text("DELETE FROM algo_adjustable_param WHERE algo_id = :algo_id"), {"algo_id": algo_id})
                session.commit()
                input_param.to_sql("algo_adjustable_param", con=db_engine, if_exists="append", index=False)
                # Save the model and scaler
                pickle.dump(scaler, open('Resources/AI_Fault_Class/V_auto/scalers/scaler_' + fault + '.pkl', 'wb'))
                model.save('Resources/AI_Fault_Class/V_auto/models/model_' + fault + '.h5')
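
# BlockingScheduler is imported at the top of this file but never started anywhere in it.
# A minimal sketch of how diag_cal() could be wired up (an assumption -- the actual
# schedule is not specified here), kept commented out so importing this module stays
# side-effect free:
#
#   scheduler = BlockingScheduler()
#   scheduler.add_job(diag_cal, 'cron', hour=2)  # e.g. retrain nightly at 02:00
#   scheduler.start()  # blocks the calling process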
import datetime
import gc
import re
from multiprocessing import Pool
import json
import logging
import logging.handlers
import os
import time
import traceback
import warnings
from sqlalchemy import text, delete, and_, or_, update
import pandas as pd
from ZlwlAlgosCommon.utils.ProUtils import *
from ZlwlAlgosCommon.service.iotp.IotpAlgoService import IotpAlgoService
from ZlwlAlgosCommon.service.iotp.Beans import DataField
from ZlwlAlgosCommon.orm.models import *
def invoke_algo1(logger, mysql_algo_conn, mysql_algo_Session, start_time, df_data):
    pass


def invoke_algo2(logger, mysql_algo_conn, mysql_algo_Session, start_time, df_data):
    pass
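
# For reference, main() below unpacks each kafka message value as a JSON object.
# _EXAMPLE_SCHEDULE_MSG is a hypothetical minimal payload: the field shapes are
# inferred from the parsing code in main(), and all values are made up:
_EXAMPLE_SCHEDULE_MSG = {
    "snlist": [{"sn": "SN0001"}],
    "adjustable_param": [{"algo_id": 101, "param": "{}", "param_ai": "{}"}],
    "pack_param": [{"param": "{\"CellVoltTotalCount\": 96, \"CellTempTotalCount\": 32, \"capacity\": 150}"}],
    "algo_list": [{"algo_name": "FaultClass_Train"}],
    "start_time": "2023-01-01 00:00:00",
    "end_time": "2023-01-01 12:00:00",
    "pack_code": "PK504",
    "cell_type": "NCM",
}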
def main(process_num):
    # The worker must keep running no matter what
    while True:
        warnings.filterwarnings("ignore")
        try:
            # Preparation before invoking the algorithms
            kafka_topic_key = 'topic_task_month_1'
            kafka_groupid_key = 'group_task_month_1'
            algo_list = ['FaultClass_Train']  # algorithm names handled by this scheduler
            loggers = sysUtils.get_loggers(algo_list, log_base_path, process_num)  # one logger per algorithm
            logger_main.info(f"process-{process_num}: configuring middleware")
            # mysql
            mysql_algo_params = sysUtils.get_cf_param('mysql-algo')
            mysqlUtils = MysqlUtils()
            mysql_algo_engine, mysql_algo_Session = mysqlUtils.get_mysql_engine(mysql_algo_params)
            mysql_algo_conn = mysql_algo_engine.connect()
            # redis
            redis_params = sysUtils.get_cf_param('redis')
            redisUtils = RedisUtils()
            redis_conn = redisUtils.get_redis_conncect(redis_params)
            # hbase
            hbase_params = sysUtils.get_cf_param('hbase')
            iotp_service = IotpAlgoService(hbase_params=hbase_params)
            # kafka
            kafka_params = sysUtils.get_cf_param('kafka')
            kafkaUtils = KafkaUtils()
            kafka_consumer = kafkaUtils.get_kafka_consumer(kafka_params, kafka_topic_key, kafka_groupid_key, client_id=kafka_topic_key)
            logger_main.info(f"process-{process_num}: fetching algorithm and pack parameters")
        except Exception as e:
            logger_main.error(f'process-{process_num}: {e}')
            logger_main.error(f'process-{process_num}: {traceback.format_exc()}')
        # Start consuming schedules
        try:
            logger_main.info(f"process-{process_num}: listening on topic {kafka_params[kafka_topic_key]}, waiting for kafka schedules")
            param_update_timer = time.time()
            for message in kafka_consumer:
                try:
                    logger_main.info(f'process-{process_num}: received schedule {message.value}')
                    if mysql_algo_conn.closed:
                        mysql_algo_conn = mysql_algo_engine.connect()  # get a mysql connection from the pool
                    schedule_params = json.loads(message.value)
                    if (schedule_params is None) or (schedule_params == ''):
                        logger_main.info(f'process-{process_num}: {message.value} bad kafka payload, skipping this run')
                        continue
                    # Parse the kafka schedule parameters
                    df_snlist = pd.DataFrame(schedule_params['snlist'])
                    df_algo_adjustable_param = pd.DataFrame([(d['algo_id'], d['param'], d['param_ai']) for d in schedule_params['adjustable_param']],
                                                            columns=['algo_id', 'param', 'param_ai'])
                    df_algo_pack_param = json.loads(schedule_params['pack_param'][0]['param'])
                    df_algo_list = pd.DataFrame(schedule_params['algo_list'])
                    start_time = schedule_params['start_time']
                    end_time = schedule_params['end_time']
                    pack_code = schedule_params['pack_code']
                    cell_type = schedule_params['cell_type']
                    sn_list = df_snlist['sn'].tolist()
                    # Fetch the tag dataset
                    hbase_params = sysUtils.get_cf_param('hbase')
                    hbase_datafactory_params = sysUtils.get_cf_param('hbase-datafactory')
                    iotp_service = IotpAlgoService(hbase_params=hbase_params)
                    iotp_datafactory_service = IotpAlgoService(hbase_params=hbase_datafactory_params)
                    mysql_datafactory_params = sysUtils.get_cf_param('mysql-datafactory')
                    mysqlUtils = MysqlUtils()
                    mysql_datafactory_engine, mysql_datafactory_Session = mysqlUtils.get_mysql_engine(mysql_datafactory_params)
                    mysql_datafactory_conn = mysql_datafactory_engine.connect()
                    df_tags_dataset = iotp_datafactory_service.get_dataset_tags(mysql_datafactory_conn)
                    # Fetch the sn/pack mapping, preferring the redis cache
                    data = rc.get("algo_param_from_mysql:t_device")
                    if pd.isnull(data):
                        df_snpk_list = pd.read_sql("select sn, imei, pack_model from t_device", db_engine)
                        df_snpk_list.rename(columns={'pack_model': 'pack_code'}, inplace=True)
                    else:
                        df_snpk_list = pd.DataFrame(json.loads(data))
                    logger_main.info(f"process-{process_num}: starting data fetch")
                    columns = [DataField.error_level, DataField.error_code, DataField.pack_crnt, DataField.pack_volt,
                               DataField.bms_sta, DataField.cell_voltage_count, DataField.cell_temp_count, DataField.cell_voltage, DataField.cell_temp,
                               DataField.pack_soc, DataField.other_temp_value, DataField.cell_balance,
                               DataField.pack_soh, DataField.charge_sta]
                    df_data = iotp_service.get_data(sn_list=sn_list, columns=columns, start_time=start_time, end_time=end_time)
                    logger_main.info(f"process-{process_num}: {str(sn_list)} fetched {str(len(df_data))} rows")
                except Exception as e:
                    logger_main.error(f'process-{process_num}: {pack_code} failed')
                    logger_main.error(f'process-{process_num}: {e}')
                    logger_main.error(f'process-{process_num}: {traceback.format_exc()}')
                try:
                    # Data cleaning
                    if len(df_data) == 0:
                        logger_main.info(f"process-{process_num}: no data, skipping this run")
                        continue
                    df_data, df_table, cellvolt_name, celltemp_name = iotp_service.data_clean(df_data, df_algo_pack_param)  # clean the raw data
                    if len(df_data) == 0:
                        logger_main.info(f"process-{process_num}: cleaning finished, no valid data, skipping this run")
                        continue
                    else:
                        logger_main.info(f"process-{process_num}: {pack_code}, time_type:{df_data.loc[0, 'time']} ~ {df_data.iloc[-1]['time']}, data cleaning finished")
                except Exception as e:
                    logger_main.error(f"process-{process_num}:{pack_code} data cleaning failed")
                    logger_main.error(f"process-{process_num}:{e}")
                    logger_main.error(f"process-{process_num}:{traceback.format_exc()}")
                # Invoke the first algorithm
                try:
                    invoke_algo1(loggers, mysql_algo_conn, mysql_algo_Session, start_time, df_data)
                except Exception as e:
                    loggers['FaultClass_Train'].error('{} failed'.format(pack_code))
                    loggers['FaultClass_Train'].error(str(e))
                    loggers['FaultClass_Train'].error(traceback.format_exc())
                # Invoke the second algorithm
                try:
                    invoke_algo2(loggers, mysql_algo_conn, mysql_algo_Session, start_time, df_data)
                except Exception as e:
                    loggers['FaultClass_Train'].error('{} failed'.format(pack_code))
                    loggers['FaultClass_Train'].error(str(e))
                    loggers['FaultClass_Train'].error(traceback.format_exc())
        except Exception as e:
            logger_main.error(f'process-{process_num}: {pack_code} failed')
            logger_main.error(f'process-{process_num}: {e}')
            logger_main.error(f'process-{process_num}: {traceback.format_exc()}')
        finally:
            iotp_service.close()
if __name__ == '__main__':
    while True:
        try:
            # Configuration
            cur_env = 'dev'  # runtime environment
            app_path = "/home/zhuxi/project/zlwl-algos/"  # absolute path of the app
            log_base_path = f"{os.path.dirname(os.path.abspath(__file__))}/log"  # log directory
            app_name = "task_second_1"  # application name
            sysUtils = SysUtils(cur_env, app_path)
            logger_main = sysUtils.get_logger(app_name, log_base_path)
            logger_main.info(f"main process pid: {os.getpid()}")
            # Read the config file (please do not modify this part)
            processes = int(sysUtils.env_params.get("PROCESS_NUM_PER_NODE", '1'))  # defaults to 1 process
            pool = Pool(processes=processes)
            logger_main.info("spawning worker processes")
            for i in range(processes):
                pool.apply_async(main, (i,))
            pool.close()
            logger_main.info("workers spawned, blocking the main process")
            pool.join()
        except Exception as e:
            logger_main.error(str(e))
            logger_main.error(traceback.format_exc())
        finally:
            handlers = logger_main.handlers.copy()
            for h in handlers:
                logger_main.removeHandler(h)
            pool.terminate()