
faultclass train

zhuxi · 2 years ago · parent commit b40d633012

+ 11 - 19
LIB/MIDDLE/FaultClass/V1_0_0/faultclass.py

@@ -176,14 +176,6 @@ def arrange(result,result_final,start_time):
             res_update=res_update.append(result0)
     return res_new,res_update
 
-def arrange2(dataorg,df_res,start_time,fault_name):
-    res_new=df_res.copy()
-    res_update=pd.DataFrame()
-    if len(dataorg)>0:
-        dataorg=dataorg[dataorg['fault_class']==fault_name]
-        res_new,res_update=arrange(df_res,dataorg,start_time)
-    return res_new,res_update
-
 # Step8 Process
 def pred(data_fea,model,scaler,col,end_time,time_steps):
     df_res=pd.DataFrame()
@@ -249,21 +241,21 @@ def make_fault_set(dataset,cols,col_key,compare,threshold_filtre,fault_name):
     return df_bms
 
 # Step4 Normal Pre-processing
-def normalset(df_bms):
-    df_bms['fault_class']='正常'
+def normalset(df_bms,cols):
     df_bms.drop(['Unnamed: 0'],axis=1,inplace=True)
-    nor_fea1=features_filtre(df_bms)
+    nor_fea1=features_filtre(df_bms,cols)
     norfea1=split(nor_fea1)
     normalf1=makedataset(norfea1)
+    normalf1['fault_class']='正常'
     return normalf1
 
-def normalset2(df_bms1,df_bms2,df_bms3,df_bms4,df_bms5,df_bms6):
-    normalf1=normalset(df_bms1)
-    normalf2=normalset(df_bms2)
-    normalf3=normalset(df_bms3)
-    normalf4=normalset(df_bms4)
-    normalf5=normalset(df_bms5)
-    normalf6=normalset(df_bms6)
+def normalset2(df_bms1,df_bms2,df_bms3,df_bms4,df_bms5,df_bms6,cols):
+    normalf1=normalset(df_bms1,cols)
+    normalf2=normalset(df_bms2,cols)
+    normalf3=normalset(df_bms3,cols)
+    normalf4=normalset(df_bms4,cols)
+    normalf5=normalset(df_bms5,cols)
+    normalf6=normalset(df_bms6,cols)
     nor=pd.concat([normalf1,normalf2,normalf3,normalf4,normalf5,normalf6])
     nor.reset_index(drop=True,inplace=True)
     return nor
@@ -420,7 +412,7 @@ def modelGRU(time_steps,nbr_features,nbr_neurons,nbr_class,Xwin,Ywin,Xtwin,Ytwin
     return model
 
 # Step11 Process
-def pre_model(nor,df_bms,time_steps,nbr_features,nbr_neurons,nbr_class,batch_size,epochs,dropout,lr,activation):
+def pre_model(nor,df_bms,time_steps,nbr_features,nbr_neurons,nbr_class,batch_size,epochs,dropout,lr,activation,loss):
     nor,df_bms=resample(nor,df_bms)
     newtrain,newtest=shuffle_data(nor,df_bms)
     train_sh=shuffle_data2(newtrain)

+ 42 - 0
LIB/MIDDLE/FaultClass/V1_0_0/main_input.py

@@ -0,0 +1,42 @@
+from sqlalchemy import create_engine
+from urllib import parse
+import pandas as pd
+import pymysql
+
+# User input parameters
+fault_name='电压采样断线'
+cols=str(['时间戳','sn','单体压差','volt_max','volt_min','volt_mean','volt_sigma','mm_volt_cont'])
+col_key='mm_volt_cont'
+compare=0
+threshold_filtre=1
+
+time_steps=12
+nbr_features=6
+nbr_neurons=5
+nbr_class=2
+batch_size=100
+epochs=5
+dropout=0.5
+lr=1e-3
+activation='softmax'
+loss='categorical_crossentropy'
+
+threshold_accuracy=0.95
+
+# Database configuration
+host='rm-bp10j10qy42bzy0q77o.mysql.rds.aliyuncs.com'
+port=3306
+db='qx_cas'
+user='qx_algo_rw'
+password='qx@123456'
+
+db_res_engine = create_engine(
+    "mysql+pymysql://{}:{}@{}:{}/{}?charset=utf8".format(
+        user, parse.quote_plus(password), host, port, db
+    ))
+#mysql = pymysql.connect (host=host, user=user, password=password, port=port, database=db)
+input_param=pd.DataFrame({'fault_name':[fault_name],'cols':[cols],'col_key':[col_key],'compare':[compare],'threshold_filtre':[threshold_filtre],
+                        'time_steps':[time_steps],'nbr_features':[nbr_features],'nbr_neurons':[nbr_neurons],'nbr_class':[nbr_class],'batch_size':[batch_size],
+                        'epochs':[epochs],'dropout':[dropout],'lr':[lr],'activation':[activation],'loss':[loss],'threshold_accuracy':[threshold_accuracy]})
+input_param.to_sql("faultclass_input",con=db_res_engine, if_exists="append",index=False)
+#mysql.close()
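
Optionally, the insert can be verified with a quick read-back; a minimal sketch (not part of the commit), reusing the pd, fault_name and db_res_engine names already defined in main_input.py above:

    # Sketch only: confirm the parameter row appended to faultclass_input.
    sql = "select fault_name, col_key, threshold_accuracy from faultclass_input where fault_name='{}'".format(fault_name)
    check = pd.read_sql(sql, db_res_engine)
    print(check.tail(1))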

+ 158 - 0
LIB/MIDDLE/FaultClass/V1_0_0/main_train.py

@@ -0,0 +1,158 @@
+from faultclass import *
+import pymysql
+import datetime
+import pandas as pd
+import datetime
+import pickle
+from LIB.BACKEND import DBManager
+dbManager = DBManager.DBManager()
+from LIB.MIDDLE.CellStateEstimation.Common import log
+mylog=log.Mylog('log.txt','error')
+mylog.logcfg()
+from sqlalchemy import create_engine
+from urllib import parse
+
+# Fault name
+fault_name='电压采样断线'
+
+# Read files: normal data
+
+df_bms1=pd.read_csv('LIB/MIDDLE/FaultClass/V1_0_0/data/dataset2.csv')
+df_bms2=pd.read_csv('LIB/MIDDLE/FaultClass/V1_0_0/data/dataset3.csv')
+df_bms3=pd.read_csv('LIB/MIDDLE/FaultClass/V1_0_0/data/dataset4.csv')
+df_bms4=pd.read_csv('LIB/MIDDLE/FaultClass/V1_0_0/data/dataset5.csv')
+df_bms5=pd.read_csv('LIB/MIDDLE/FaultClass/V1_0_0/data/dataset6.csv')
+df_bms6=pd.read_csv('LIB/MIDDLE/FaultClass/V1_0_0/data/dataset7.csv')
+
+# Database configuration
+host0='rm-bp10j10qy42bzy0q77o.mysql.rds.aliyuncs.com'
+port0=3306
+db0='qx_cas'
+user0='qx_algo_rw'
+password0='qx@123456'
+
+# Read data from the results database......................................................
+param='fault_name,cols,col_key,compare,threshold_filtre,time_steps,nbr_features,nbr_neurons,nbr_class,batch_size,epochs,dropout,lr,activation,loss,threshold_accuracy'
+tablename='faultclass_input'
+mysql = pymysql.connect (host=host0, user=user0, password=password0, port=port0, database=db0)
+cursor = mysql.cursor()
+sql =  "select {} from {} where fault_name='{}'".format(param,tablename,fault_name)
+cursor.execute(sql)
+res = cursor.fetchall()
+list_param= pd.DataFrame(res,columns=param.split(','))
+list_param.reset_index(drop=True,inplace=True)
+
+db_res_engine = create_engine(
+    "mysql+pymysql://{}:{}@{}:{}/{}?charset=utf8".format(
+        user0, parse.quote_plus(password0), host0, port0, db0
+    ))
+
+# User input parameters
+cols=eval(list_param.loc[0,'cols'])
+col_key=list_param.loc[0,'col_key']
+compare=list_param.loc[0,'compare']
+threshold_filtre=list_param.loc[0,'threshold_filtre']
+
+time_steps=list_param.loc[0,'time_steps']
+nbr_features=list_param.loc[0,'nbr_features']
+nbr_neurons=list_param.loc[0,'nbr_neurons']
+nbr_class=list_param.loc[0,'nbr_class']
+batch_size=list_param.loc[0,'batch_size']
+epochs=list_param.loc[0,'epochs']
+dropout=list_param.loc[0,'dropout']
+lr=list_param.loc[0,'lr']
+activation=list_param.loc[0,'activation']
+loss=list_param.loc[0,'loss']
+
+threshold_accuracy=list_param.loc[0,'threshold_accuracy']
+
+# Database configuration
+host='rm-bp10j10qy42bzy0q77o.mysql.rds.aliyuncs.com'
+port=3306
+db='safety_platform'
+user='qx_read'
+password='Qx@123456'
+
+# Read current faults from the fault results database......................................................
+param='start_time,end_time,product_id,code,info'
+tablename='all_fault_info'
+mysql = pymysql.connect (host=host, user=user, password=password, port=port, database=db)
+cursor = mysql.cursor()
+#sql =  "select %s from %s where end_time='0000-00-00 00:00:00'" %(param,tablename)
+sql =  "select %s from %s" %(param,tablename)
+cursor.execute(sql)
+res = cursor.fetchall()
+df_diag_ram= pd.DataFrame(res,columns=param.split(','))
+
+df_diag_ram.dropna(inplace=True)
+df_diag_ram.reset_index(drop=True,inplace=True)
+
+# Database configuration
+host2='rm-bp10j10qy42bzy0q7.mysql.rds.aliyuncs.com'
+port=3306
+db2='zhl_omp_v2'
+user2='zhl_omp'
+password='Qx@123456'
+
+# Read current faults from the fault results database......................................................
+param='fault_time,fault_code,sn,child_tag,tag_type,update_time'
+tablename='t_cloud_control'
+mysql = pymysql.connect (host=host2, user=user2, password=password, port=port, database=db2)
+cursor = mysql.cursor()
+#sql =  "select %s from %s where end_time='0000-00-00 00:00:00'" %(param,tablename)
+sql =  "select %s from %s" %(param,tablename)
+cursor.execute(sql)
+res = cursor.fetchall()
+df_diag_ram2= pd.DataFrame(res,columns=param.split(','))
+
+df_diag_ram2.dropna(inplace=True)
+df_diag_ram2.reset_index(drop=True,inplace=True)
+
+# Read the child-problem (fault name) mapping table......................................................
+param='id,parent_id,name'
+tablename='t_child_problem'
+mysql = pymysql.connect (host=host2, user=user2, password=password, port=port, database=db2)
+cursor = mysql.cursor()
+#sql =  "select %s from %s where end_time='0000-00-00 00:00:00'" %(param,tablename)
+sql =  "select %s from %s" %(param,tablename)
+cursor.execute(sql)
+res = cursor.fetchall()
+df_diag_ram3= pd.DataFrame(res,columns=param.split(','))
+
+df_diag_ram3['id']=list(map(lambda x:str(x),list(df_diag_ram3['id'])))
+df_diag=pd.merge(df_diag_ram2,df_diag_ram3,how='left',left_on=['child_tag','tag_type'],right_on=['id','parent_id'])
+
+df_diag['fault_time']=list(map(lambda x:str(x),list(df_diag['fault_time'])))
+df_diag2=pd.merge(df_diag,df_diag_ram,how='left',left_on=['fault_time','sn','fault_code'],right_on=['start_time','product_id','code'])
+df_diag3=df_diag2.sort_values(by='update_time',ascending=False)
+
+df=df_diag2[df_diag2['name']==fault_name]
+df.reset_index(drop=True,inplace=True)
+
+dataset=pd.DataFrame()
+for k in range(len(df)):
+    try: 
+        sn =df.loc[k,'product_id']
+        start_time=str(df.loc[k,'start_time'])
+        end_time=df.loc[k,'end_time']
+        if end_time=='0000-00-00 00:00:00':
+            end_time=datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')   #type: str
+        df_data = dbManager.get_data(sn=sn, start_time=start_time, end_time=end_time, data_groups=['bms'])
+        data_bms = df_data['bms']
+        data_bms['sn']=sn   
+        dataset=dataset.append(data_bms)
+        dataset.to_csv('LIB/MIDDLE/FaultClass/V1_0_0/data/fault_'+fault_name+'.csv')
+    except Exception as e:
+        print(repr(e))
+        mylog.logopt(sn,e)
+        pass 
+
+df_bms=make_fault_set(dataset,cols,col_key,compare,threshold_filtre,fault_name)
+nor=normalset2(df_bms1,df_bms2,df_bms3,df_bms4,df_bms5,df_bms6,cols)
+scaler,model,acc=pre_model(nor,df_bms,time_steps,nbr_features,nbr_neurons,nbr_class,batch_size,epochs,dropout,lr,activation,loss)
+df_acc=pd.DataFrame({'fault_name':[fault_name],'accuracy':[acc]})
+df_acc.to_sql("faultclass_output",con=db_res_engine, if_exists="append",index=False)
+
+if acc>threshold_accuracy:
+    model.save('models/model_'+fault_name+'.h5')
+    pickle.dump(scaler,open('models/scaler_'+fault_name+'.pkl','wb'))
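
For reference, a minimal sketch (not part of the commit) of loading the saved artifacts back for inference; it assumes the same Keras/TensorFlow install used for training and the models/ paths written above:

    # Sketch only: reload the artifacts that main_train.py saves when
    # acc exceeds threshold_accuracy.
    import pickle
    from keras.models import load_model  # or tensorflow.keras.models, depending on the install

    fault_name = '电压采样断线'
    model = load_model('models/model_' + fault_name + '.h5')
    with open('models/scaler_' + fault_name + '.pkl', 'rb') as f:
        scaler = pickle.load(f)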

BIN
LIB/MIDDLE/FaultClass/V1_0_0/models/model_电压采样断线.h5


BIN
LIB/MIDDLE/FaultClass/V1_0_0/models/scaler_电压采样断线.pkl