import datetime
import random
from random import shuffle

import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler

# Flat public import paths: keras.layers.recurrent and the adam_v2 module
# were internal locations that newer Keras releases no longer expose.
from keras.layers import GRU, Activation, Dense, Dropout, Input
from keras.models import Model
from keras.optimizers import Adam

#### Process1 - Prediction - Model1 + Model2 ####

# Step1 Features

# Model1 
def features1(dataset2):
    # Drop status/diagnostic columns that are not used as voltage-model features
    dataset2=dataset2.drop(['GSM信号','故障等级','故障代码','开关状态','绝缘电阻','外电压','总输出状态','上锁状态','加热状态','单体均衡状态','充电状态','SOH[%]','SOC[%]','总电流[A]'],axis=1,errors='ignore')
    cellvolt_list = [s for s in list(dataset2) if '单体电压' in s]   # per-cell voltage columns
    celltemp_name = [s for s in list(dataset2) if '温度' in s]       # all temperature columns
    dataset2=dataset2.drop(celltemp_name,axis=1)
    # Collapse the per-cell voltages into pack-level max/min features
    dataset2['volt_max']=dataset2[cellvolt_list].max(axis=1)
    dataset2['volt_min']=dataset2[cellvolt_list].min(axis=1)
    dataset2=dataset2.drop(cellvolt_list,axis=1)
    dataset2.reset_index(drop=True,inplace=True)
    return dataset2
# Model2
def features2(dataset2):
    # Drop status/diagnostic columns that are not used as temperature-model features
    dataset2=dataset2.drop(['GSM信号','故障等级','故障代码','开关状态','绝缘电阻','外电压','总输出状态','上锁状态','加热状态','单体均衡状态','充电状态','SOH[%]','SOC[%]','单体压差','总电压[V]'],axis=1,errors='ignore')
    cellvolt_list = [s for s in list(dataset2) if '单体电压' in s]    # per-cell voltage columns
    celltemp_name = [s for s in list(dataset2) if '单体温度' in s]    # per-cell temperature columns
    celltemp_name2 = [s for s in list(dataset2) if '其他温度' in s]   # other temperature probes
    dataset2=dataset2.drop(cellvolt_list+celltemp_name2,axis=1)
    # Collapse the per-cell temperatures into pack-level max/min/spread features
    dataset2['temp_max']=dataset2[celltemp_name].max(axis=1)
    dataset2['temp_min']=dataset2[celltemp_name].min(axis=1)
    dataset2['temp_diff']=dataset2['temp_max']-dataset2['temp_min']
    dataset2=dataset2.drop(celltemp_name,axis=1)
    dataset2.reset_index(drop=True,inplace=True)
    return dataset2
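# Usage sketch for the two extractors (hypothetical two-cell frame; real logs
# carry many '单体电压'/'单体温度' columns plus the status fields dropped above):
#   raw = pd.DataFrame({'时间戳':['2021-01-01 00:00:00'],'sn':['V001'],
#                       '总电压[V]':[350.2],'单体电压1':[3.51],'单体电压2':[3.49],
#                       '单体温度1':[25.0],'单体温度2':[27.0]})
#   features1(raw)   # keeps 时间戳/sn/总电压[V] + volt_max/volt_min
#   features2(raw)   # keeps 时间戳/sn + temp_max/temp_min/temp_diff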

# Step2 Splits
def split(df_bms_tot):
    # Segment the log into continuous sessions: a new segment starts whenever
    # two consecutive rows are more than 10 min apart or belong to different
    # vehicles (sn); the running segment id is written to the 'split' column.
    df_bms_tot['split']=0
    for k in range(1,len(df_bms_tot)):
        timek=datetime.datetime.strptime(df_bms_tot.loc[k,'时间戳'],'%Y-%m-%d %H:%M:%S')
        timek1=datetime.datetime.strptime(df_bms_tot.loc[k-1,'时间戳'],'%Y-%m-%d %H:%M:%S')
        deltatime=(timek-timek1).total_seconds()
        if (deltatime>600) or (df_bms_tot.loc[k,'sn']!=df_bms_tot.loc[k-1,'sn']):
            df_bms_tot.loc[k,'split']=df_bms_tot.loc[k-1,'split']+1
        else:
            df_bms_tot.loc[k,'split']=df_bms_tot.loc[k-1,'split']
    return df_bms_tot
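# Usage sketch (hypothetical timestamps/sn; the 1 h gap opens segment 1):
#   df = pd.DataFrame({'时间戳':['2021-01-01 00:00:00','2021-01-01 00:00:10',
#                                '2021-01-01 01:00:00'],'sn':['V001','V001','V001']})
#   split(df)['split'].tolist()   # -> [0, 0, 1]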

# Step3 MakeDataset: TimeSeries
def makedataset(dataset):
    frames=[]
    for split_id in list(set(dataset['split'])):
        set2=dataset[dataset['split']==split_id]
        set2=set2.reset_index(drop=True)
        data_set=pd.DataFrame()
        start=set2.loc[0,'时间戳']
        end=set2.loc[len(set2)-1,'时间戳']
        data_set['Time']=pd.date_range(start=start, end=end, freq='S')  # one row per second
        data_set['Time']=list(map(lambda x:str(x),list(data_set['Time'])))
        dfbms=pd.merge(data_set,set2,left_on='Time',right_on='时间戳',how='left')
        dfbms=dfbms.ffill().bfill()   # propagate the last known record across gaps
        dfbms=dfbms.drop(['时间戳'],axis=1)
        # Truncate the last seconds digit to '0' and keep the last row per
        # bucket, i.e. thin the 1 s grid down to one record every 10 s
        dfbms['Time']=list(map(lambda x:x[:18]+'0',list(dfbms['Time'])))
        dfbms.drop_duplicates(subset='Time',keep='last',inplace=True)
        frames.append(dfbms)
    df_bms=pd.concat(frames) if frames else pd.DataFrame()
    df_bms.reset_index(drop=True,inplace=True)
    return df_bms
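# Sketch of the resampling effect: each segment is expanded to a 1 s grid
# (gaps forward/backward filled), then deduplicated per 10 s bucket, so a
# segment running 00:00:00-00:01:00 yields the 7 grid points 00:00:00,
# 00:00:10, ..., 00:01:00:
#   series = makedataset(split(df))   # df as in the split() sketch above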

# Step4 Scaler
def scaler_pred(df_bms,scaler):
    # Apply the StandardScaler fitted at training time to the feature columns
    Xtest=df_bms.drop(['Time','sn','split'],axis=1)
    Xsc_colnames=list(Xtest.columns)
    Xtsc=pd.DataFrame(scaler.transform(np.array(Xtest)),columns=Xsc_colnames)
    return Xtsc

# Step5 MakeIndex
def make_index(train):
    # Return the sorted first row index of every segment, plus the total row
    # count, so that consecutive index pairs bound each segment for windowing
    indextr=[]
    for i in list(set(train['split'])):
        tr=train[train['split'] == i].index.tolist()
        indextr.append(min(tr))
    indextr=sorted(indextr)
    indextr.append(len(train))
    return indextr
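# Usage sketch: consecutive pairs of the returned indices bound each segment,
#   make_index(pd.DataFrame({'split':[0,0,0,1,1]}))   # -> [0, 3, 5]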

# Step6 CreateWindows
def create_win_pred(X2,Xtest,index,time_steps=12):
    conf_parts=[]
    a=[]
    for k in range(1,len(index)):
        # Slice one continuous segment out of the scaled and the raw frames
        dataset=X2[index[k-1]:index[k]].reset_index(drop=True)
        dataset2=Xtest[index[k-1]:index[k]].reset_index(drop=True)
        if len(dataset)>time_steps:
            dataX = []
            for i in range(len(dataset)-time_steps):
                # One sliding window of time_steps consecutive scaled rows
                v1 = dataset.iloc[i:(i+time_steps)].values
                dataX.append(v1)
            # Raw rows aligned with each window's first timestamp
            test=dataset2.iloc[:len(dataset)-time_steps]
            conf_parts.append(test)
            a.append(np.array(dataX,dtype='float32'))
    if len(a)>0:
        aa=np.vstack(a)
        conf=pd.concat(conf_parts)
    else:
        aa=[]
        conf=pd.DataFrame()
    conf.reset_index(drop=True,inplace=True)
    return aa,conf
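# Shape sketch (hypothetical sizes): one segment of 100 rows with 5 scaled
# feature columns and time_steps=12 gives
#   aa.shape  -> (88, 12, 5)   # 100-12 sliding windows
#   len(conf) -> 88            # raw rows aligned with the window starts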

# Step7 Prediction
def prediction(model,cc,conf,col):
    # Predict class probabilities per window, then keep the name of the
    # highest-scoring class as the prediction for that window
    predict_dd = model.predict(cc)
    df_pred=pd.DataFrame(predict_dd)
    df_pred.columns=col
    conf['pred']=df_pred.idxmax(axis=1)
    return conf
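# Usage sketch (class names hypothetical; col comes from xy() at training time):
#   res = prediction(gru_model, aa, conf, col=['正常','fault_A','fault_B'])
#   res['pred']   # winning class name per window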

# Step8 Output
def makeres(res,end_time):
    # Build one output row per predicted fault class ('正常' = normal is skipped)
    rows=[]
    result_faults=res[res['pred']!='正常']
    list_faults=list(set(list(result_faults['pred'])))
    for fault in list_faults:
        res_faults=result_faults[result_faults['pred']==fault]
        res_faults.reset_index(drop=True,inplace=True)
        update_time=str(res_faults.loc[len(res_faults)-1,'Time'])
        end=datetime.datetime.strptime(str(res_faults.loc[len(res_faults)-1,'Time']),'%Y-%m-%d %H:%M:%S')
        end_time2=datetime.datetime.strptime(str(end_time),'%Y-%m-%d %H:%M:%S')
        # A fault last seen within 15 min of the batch end is considered still
        # open: its end_time is set to the all-zero sentinel
        if (end_time2-end).total_seconds()<900:
            res_faults.loc[len(res_faults)-1,'Time']='0000-00-00 00:00:00'
        rows.append(pd.DataFrame({'product_id':[res_faults.loc[0,'sn']],'start_time':[str(res_faults.loc[0,'Time'])],
                        'end_time':[str(res_faults.loc[len(res_faults)-1,'Time'])],'fault_class':[res_faults.loc[0,'pred']],
                        'update_time':[update_time]}))
    if rows:
        df_res=pd.concat(rows,ignore_index=True)
    else:
        df_res=pd.DataFrame(columns=['product_id','start_time','end_time','fault_class','update_time'])
    return df_res

# Step9 Process
def pred(data_fea,model,scaler,col,end_time,time_steps):
    df_res=pd.DataFrame()
    fea=split(data_fea)          # segment into continuous sessions
    f=makedataset(fea)           # resample each session to a 10 s grid
    sc=scaler_pred(f,scaler)     # scale with the training-time scaler
    index=make_index(f)          # segment boundaries for windowing
    dataX,conf=create_win_pred(sc,f,index,time_steps=time_steps)
    if len(dataX)>0:
        res=prediction(model,dataX,conf,col)
        df_res=makeres(res,end_time)
    return df_res
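# End-to-end prediction sketch (all names hypothetical: gru_model, scaler and
# class_names come from pre_model() below, raw_log is one vehicle's BMS log):
#   df_res = pred(features1(raw_log), gru_model, scaler, class_names,
#                 end_time='2021-01-01 12:00:00', time_steps=12)
#   # columns: product_id / start_time / end_time / fault_class / update_time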

# Step10 Merge
def arrange(result,result_final):
    # Merge the newest result against the stored open record (result_final,
    # a single row): if the new fault starts within 1 h of the stored record's
    # update_time, extend the stored record, otherwise close it at update_time
    result.reset_index(drop=True,inplace=True)
    res_update=pd.DataFrame()
    res_new=result.copy()
    if len(result)>0:
        st=datetime.datetime.strptime(str(result.loc[0,'start_time']),'%Y-%m-%d %H:%M:%S')
        end=datetime.datetime.strptime(str(result_final['update_time']),'%Y-%m-%d %H:%M:%S')
        if (st-end).total_seconds()<3600:
            result_final['end_time']=result.loc[0,'end_time']
            result_final['update_time']=result.loc[0,'update_time']
        else:
            result_final['end_time']=result_final['update_time']
        res_update=result_final.copy()
        res_new.drop(result.index,inplace=True)
    else:
        result_final['end_time']=result_final['update_time']
        res_update=result_final.copy()
    return res_new,res_update

def arrange2(dataorg,df_res,time_stepsi):
    res_new=df_res.copy()
    res_update=pd.DataFrame()
    if len(dataorg)>0:
        res_new,res_update=arrange(df_res,dataorg)
    # Drop closed faults that lasted less than one time window
    if len(res_new)>0:
        for i in range(len(res_new)):
            if res_new.loc[i,'end_time'] != '0000-00-00 00:00:00':
                st1=datetime.datetime.strptime(str(res_new.loc[i,'start_time']),'%Y-%m-%d %H:%M:%S')
                end1=datetime.datetime.strptime(str(res_new.loc[i,'end_time']),'%Y-%m-%d %H:%M:%S')
                if (end1-st1).total_seconds()<time_stepsi:
                    res_new.drop([i],axis=0,inplace=True)
    # res_update comes back from arrange() as a single row; rebuild it as a
    # one-row DataFrame and apply the same minimum-duration rule
    if len(res_update)>0:
        row=pd.DataFrame({'product_id':[res_update['product_id']],'start_time':[str(res_update['start_time'])],
                        'end_time':[str(res_update['end_time'])],'fault_class':[res_update['fault_class']],
                        'update_time':[res_update['update_time']]})
        if res_update['end_time']!= '0000-00-00 00:00:00':
            st2=datetime.datetime.strptime(str(res_update['start_time']),'%Y-%m-%d %H:%M:%S')
            end2=datetime.datetime.strptime(str(res_update['end_time']),'%Y-%m-%d %H:%M:%S')
            res_update=row if (end2-st2).total_seconds()>=time_stepsi else pd.DataFrame()
        else:
            res_update=row
    return res_new,res_update
#################################################################################################################################

#### Process1 - New Model ####

# Step1 Features Filter
def features_filtre(dataset2,cols):
    dataset2=dataset2.drop(['GSM信号','故障等级','故障代码','开关状态','绝缘电阻','外电压','总输出状态','上锁状态','加热状态','单体均衡状态','充电状态','SOH[%]'],axis=1,errors='ignore')
    cellvolt_list = [s for s in list(dataset2) if '单体电压' in s]    # per-cell voltage columns
    celltemp_name = [s for s in list(dataset2) if '单体温度' in s]    # per-cell temperature columns
    celltemp_name2 = [s for s in list(dataset2) if '其他温度' in s]   # other temperature probes
    # Pack-level voltage statistics
    dataset2['volt_max']=dataset2[cellvolt_list].max(axis=1)
    dataset2['volt_min']=dataset2[cellvolt_list].min(axis=1)
    dataset2['volt_mean'] = round(dataset2[cellvolt_list].mean(axis=1),3)  # row-wise mean
    dataset2['volt_sigma'] =list(dataset2[cellvolt_list].apply(lambda x: np.std(x.values),axis=1))
    # 1-based indices of the cells holding the max and min voltage
    cell_volt_max =list(dataset2[cellvolt_list].apply(lambda x: np.argmax(x.values)+1,axis=1))
    cell_volt_min =list(dataset2[cellvolt_list].apply(lambda x: np.argmin(x.values)+1,axis=1))
    # mm_volt_cont = 1 when the max- and min-voltage cells are adjacent
    # (index distance 1, or len-1 for the wrap-around pair), else 0
    dataset2['mm_volt_cont'] = list(np.array(cell_volt_max) - np.array(cell_volt_min))
    dataset2['mm_volt_cont']=list(map(lambda x : 1 if (abs(x)==1) | (abs(x)==len(cellvolt_list)-1) else 0, list(dataset2['mm_volt_cont'])))
    dataset2=dataset2.drop(cellvolt_list+celltemp_name2,axis=1)
    # Pack-level temperature statistics
    dataset2['temp_max']=dataset2[celltemp_name].max(axis=1)
    dataset2['temp_min']=dataset2[celltemp_name].min(axis=1)
    dataset2['temp_diff']=dataset2['temp_max']-dataset2['temp_min']
    dataset2=dataset2.drop(celltemp_name,axis=1)
    datatest3=dataset2[cols]   # keep only the requested feature columns
    datatest3.reset_index(drop=True,inplace=True)
    return datatest3
    
# Step2 Data Filter
def data_filtre(datatest3,col_key,compare,threshold):
    # compare selects the operator: 0 -> ==, 1 -> >, anything else -> <
    if compare==0:
        datatest4=datatest3[datatest3[col_key]==threshold]
    elif compare==1:
        datatest4=datatest3[datatest3[col_key]>threshold]
    else:
        datatest4=datatest3[datatest3[col_key]<threshold]
    datatest4.reset_index(drop=True,inplace=True)
    return datatest4
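# Usage sketch (threshold values hypothetical):
#   data_filtre(datatest3,'temp_max',1,55)    # keep rows with temp_max > 55
#   data_filtre(datatest3,'volt_min',2,2.8)   # keep rows with volt_min < 2.8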

# Step3 Faults Pre-processing
def make_fault_set(dataset,cols,col_key,compare,threshold_filtre,fault_name):
    # Feature extraction -> threshold filter -> session split -> time series,
    # then label every row with the given fault class
    datatest3=features_filtre(dataset,cols)
    datatest4=data_filtre(datatest3,col_key,compare,threshold_filtre)
    df_tot=split(datatest4)
    df_bms=makedataset(df_tot)
    df_bms['fault_class']=fault_name
    return df_bms

# Step4 Normal Pre-processing
def normalset(df_bms,cols):
    df_bms.drop(['Unnamed: 0'],axis=1,inplace=True,errors='ignore')   # stray CSV index column
    nor_fea1=features_filtre(df_bms,cols)
    norfea1=split(nor_fea1)
    normalf1=makedataset(norfea1)
    normalf1['fault_class']='正常'
    return normalf1

def normalset2(df_bms1,df_bms2,df_bms3,df_bms4,df_bms5,df_bms6,cols):
    # Pre-process six normal-behaviour logs and stack them into one frame
    frames=[normalset(df,cols) for df in (df_bms1,df_bms2,df_bms3,df_bms4,df_bms5,df_bms6)]
    nor=pd.concat(frames)
    nor.reset_index(drop=True,inplace=True)
    return nor

# Step5 Resample
def resample(nor,df_bms):
    # If one class has more than twice the rows of the other, randomly drop
    # whole segments from the larger class to roughly balance the two sets
    if len(nor)>2*len(df_bms):
        sp=list(set(list(nor['split'])))
        sp_ran=random.sample(sp, k=int(len(sp)*(len(df_bms)/len(nor))))
        nor=nor[nor['split'].isin(sp_ran)]
        nor.reset_index(drop=True,inplace=True)
    if 2*len(nor)<len(df_bms):
        sp=list(set(list(df_bms['split'])))
        sp_ran=random.sample(sp, k=int(len(sp)*(len(nor)/len(df_bms))))
        df_bms=df_bms[df_bms['split'].isin(sp_ran)]
        df_bms.reset_index(drop=True,inplace=True)
    return nor,df_bms

# Step6 Shuffle Data
def shuffle_data(nor,dataset_faults):
    # Split at vehicle (sn) level so no vehicle appears in both train and
    # test: 80% of the shuffled sn list goes to train and 20% to test, for
    # the normal set and the fault set separately
    sn_nor=list(set(nor['sn']))
    sn_fau=list(set(dataset_faults['sn']))
    shuffle(sn_nor)
    shuffle(sn_fau)
    n_nor=int(0.8*len(sn_nor))
    n_fau=int(0.8*len(sn_fau))
    newtrain=pd.concat([nor[nor['sn'].isin(sn_nor[:n_nor])],
                        dataset_faults[dataset_faults['sn'].isin(sn_fau[:n_fau])]])
    newtest=pd.concat([nor[nor['sn'].isin(sn_nor[n_nor:])],
                       dataset_faults[dataset_faults['sn'].isin(sn_fau[n_fau:])]])
    newtrain.reset_index(drop=True,inplace=True)
    newtest.reset_index(drop=True,inplace=True)
    return newtrain,newtest

def shuffle_data2(dftrain):
    # Re-order the frame by randomly shuffled vehicle (sn) groups
    sp=list(set(dftrain['sn']))
    shuffle(sp)
    newtrain=pd.concat([dftrain[dftrain['sn']==s] for s in sp])
    newtrain.reset_index(drop=True,inplace=True)
    return newtrain

# Step7 X & Y
def xy(train):
    Xtrain=train.drop(['fault_class','Time','sn','split'],axis=1)
    Ytrain=train[['fault_class']]
    Ytrain2=pd.get_dummies(Ytrain,columns=['fault_class'],prefix_sep='_')
    # get_dummies names the columns 'fault_class_<label>'; strip the
    # 12-character 'fault_class_' prefix to recover the class names
    cols=list(map(lambda x:x[12:],list(Ytrain2.columns)))
    return Xtrain,Ytrain,Ytrain2,cols
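# Sketch of the prefix stripping:
#   y2 = pd.get_dummies(pd.DataFrame({'fault_class':['正常','正常']}),
#                       columns=['fault_class'],prefix_sep='_')
#   [c[12:] for c in y2.columns]   # -> ['正常']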

# Step8 Scaler 
def scaler_train(Xtrain):
    Xsc_colnames=list(Xtrain.columns)
    scaler=StandardScaler()
    scaler.fit(Xtrain)   # store the training set's mean and standard deviation
    Xsc=pd.DataFrame(scaler.transform(np.array(Xtrain)))
    Xsc.columns=Xsc_colnames
    return Xsc,scaler

def scaler_test(Xtest,scaler):
    Xsc_colnames=list(Xtest.columns)
    Xtsc=scaler.transform(np.array(Xtest))
    Xtsc=pd.DataFrame(Xtsc)
    Xtsc.columns=Xsc_colnames
    return Xtsc

# Step9 Create windows 
def create_win_train(X2,Y2,index,time_steps=6):
    a,b=[],[]
    for k in range(1,len(index)):
        # One continuous segment of scaled features and one-hot labels
        dataset=X2[index[k-1]:index[k]].reset_index(drop=True)
        datay=Y2[index[k-1]:index[k]].reset_index(drop=True)
        if len(dataset)<=time_steps:
            continue   # segment too short to cut a single window
        dataX, dataY = [], []
        for i in range(len(dataset)-time_steps):
            v1 = dataset.iloc[i:(i+time_steps)].values   # sliding window
            v2 = datay.iloc[i].values                    # label at the window's first row
            dataX.append(v1)
            dataY.append(v2)
        a.append(np.array(dataX,dtype='float32'))
        b.append(np.array(dataY))
    aa=np.vstack(a)
    bb=np.vstack(b)
    return aa,bb
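# Shape sketch (hypothetical sizes): one segment of 100 rows, 5 features,
# 3 classes and time_steps=6 gives
#   aa.shape -> (94, 6, 5)   # windows
#   bb.shape -> (94, 3)      # one-hot label per window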

def create_win_test(X2,Y2,Xtest,index,time_steps=12):
    a,b=[],[]
    conf_parts=[]
    for k in range(1,len(index)):
        dataset=X2[index[k-1]:index[k]].reset_index(drop=True)
        datay=Y2[index[k-1]:index[k]].reset_index(drop=True)
        dataset2=Xtest[index[k-1]:index[k]].reset_index(drop=True)
        if len(dataset)<=time_steps:
            continue   # segment too short to cut a single window
        dataX, dataY = [], []
        win_step=[]
        for i in range(len(dataset)-time_steps):
            win_step.append(i)
            dataX.append(dataset.iloc[i:(i+time_steps)].values)
            dataY.append(datay.iloc[i].values)
        # Raw rows aligned with each window start, annotated with the window
        # number and the true one-hot labels, for later inspection
        test=dataset2.iloc[:len(dataset)-time_steps].copy()
        test['win']=win_step
        test=pd.merge(test,datay,left_index=True,right_index=True)
        a.append(np.array(dataX,dtype='float32'))
        b.append(np.array(dataY))
        conf_parts.append(test)
    aa=np.vstack(a)
    bb=np.vstack(b)
    conf=pd.concat(conf_parts)
    conf.reset_index(drop=True,inplace=True)
    return aa,bb,conf

# Step10 Create Model
def modelGRU(time_steps,nbr_features,nbr_neurons,nbr_class,Xwin,Ywin,Xtwin,Ytwin,batch_size,epochs,dropout,lr,activation,loss,metrics):
    # Single GRU layer over windows of shape (time_steps, nbr_features),
    # followed by dropout and a dense classification head
    inputs = Input(shape=[time_steps,nbr_features])
    x = GRU(nbr_neurons,return_sequences=False,return_state=False)(inputs)
    x = Dropout(dropout)(x)
    x = Dense(nbr_class)(x)
    x = Dropout(dropout)(x)
    x = Activation(activation)(x)
    model = Model(inputs,x)
    model.compile(loss = loss,optimizer = Adam(learning_rate=lr),metrics = [metrics])
    model.fit(Xwin,Ywin,epochs=epochs,validation_data=(Xtwin,Ytwin),batch_size=batch_size,verbose=1,shuffle=True)
    return model
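# Example call (hyperparameter values are illustrative, not tuned):
#   model = modelGRU(time_steps=12, nbr_features=Xwin.shape[2], nbr_neurons=64,
#                    nbr_class=Ywin.shape[1], Xwin=Xwin, Ywin=Ywin,
#                    Xtwin=Xtwin, Ytwin=Ytwin, batch_size=256, epochs=10,
#                    dropout=0.3, lr=1e-3, activation='softmax',
#                    loss='categorical_crossentropy', metrics='accuracy')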

# Step11 Process
def pre_model(nor,df_bms,time_steps,nbr_features,nbr_neurons,nbr_class,batch_size,epochs,dropout,lr,activation,loss):
    nor,df_bms=resample(nor,df_bms)              # balance normal vs fault segments
    newtrain,newtest=shuffle_data(nor,df_bms)    # vehicle-level train/test split
    train_sh=shuffle_data2(newtrain)
    test_sh=shuffle_data2(newtest)
    Xtrain,Ytrain,Ytrain2,cols_train=xy(train_sh)
    Xtest,Ytest,Ytest2,cols_test=xy(test_sh)
    Xsc,scaler=scaler_train(Xtrain)
    Xtsc=scaler_test(Xtest,scaler)
    indextr=make_index(train_sh)
    indexte=make_index(test_sh)
    Xwin,Ywin=create_win_train(Xsc,Ytrain2,indextr,time_steps=time_steps)
    Xtwin,Ytwin,conf=create_win_test(Xtsc,Ytest2,test_sh,indexte,time_steps=time_steps)
    model=modelGRU(time_steps=time_steps,nbr_features=nbr_features,nbr_neurons=nbr_neurons,nbr_class=nbr_class,Xwin=Xwin,Ywin=Ywin,
                    Xtwin=Xtwin,Ytwin=Ytwin,batch_size=batch_size,epochs=epochs,dropout=dropout,lr=lr,activation=activation,
                    loss=loss,metrics='accuracy')
    val_loss,acc=model.evaluate(Xtwin,Ytwin)     # held-out loss and accuracy
    return scaler,model,acc,cols_train
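# End-to-end training sketch (inputs hypothetical: `nor` from normalset2 and
# `fault_set` from make_fault_set, built on the same `cols`; nbr_features must
# match the number of scaled feature columns):
#   scaler, gru_model, acc, class_names = pre_model(
#       nor, fault_set, time_steps=12, nbr_features=8, nbr_neurons=64,
#       nbr_class=2, batch_size=256, epochs=10, dropout=0.3, lr=1e-3,
#       activation='softmax', loss='categorical_crossentropy')
#   # scaler and class_names are exactly what the Step9 pred() pipeline consumes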