隐马尔科夫模型

特征向量:跟踪框位置相对轨迹中心的比值,角度,速度。

马尔科夫模型:

State Sequence, q1 q2 ...... qT

t个状态之间的转移可见,则这个时间序列的概率是πq1 × aq1q2 × ...... × aqT-1qT

隐马尔科夫模型:

状态不可见(隐藏),只能从观察值推测出,所以由观察值推测该时刻的状态有个观察值概率b.

πq1 × bq1( o1 ) × aq1q2 × bq2( o2 ) × ...... × aqT-1qT × bqT( oT ),

三个问题:

1。评价问题——前向后向算法

计算所有可能路径的概率。

前向:计算时向前运用结合律,意义??

2。解码问题

计算所有可能路径中概率最大的一条路径

3。学习问题

給定一個觀察序列o1 o2 ...... oT,更新ABΠ使得Evaluation Problem算得的機率盡量大。

程序:

解码问题:

double CHMM::Decode(vector<double*>& seq, vector<int>& state)
{
    // Viterbi
    int size = (int)seq.size();
    double* lastLogP = new double[m_stateNum];
    double* currLogP = new double[m_stateNum];
    int** path = new int*[size];
    int i,j,t;

    // Init
    path[0] = new int[m_stateNum];
    for ( i = 0; i < m_stateNum; i++)
    {
        currLogP[i] = LogProb(m_stateInit[i]) + 
            LogProb(m_stateModel[i]->GetProbability(seq[0]));
        path[0][i] = -1;
    }

    // Recursion
    for ( t = 1; t < size; t++)  //对每一个观测,求属于每个状态的当前最大累加概率
    {
        path[t] = new int[m_stateNum];
        double* temp = lastLogP;
        lastLogP = currLogP;
        currLogP = temp;

        for ( i = 0; i < m_stateNum; i++)
        {
            currLogP[i] = -1e308;
            // Searching the max for last state.
            for ( j = 0; j < m_stateNum; j++)
            {
                double l = lastLogP[j] + LogProb(m_stateTran[j][i]);
                if (l > currLogP[i])
                {
                    currLogP[i] = l;
                    path[t][i] = j;
                }
            }
            currLogP[i] += LogProb(m_stateModel[i]->GetProbability(seq[t]));
        }
    }

    // Termination
    int finalState = 0;
    double prob = -1e308;
    for ( i = 0; i < m_stateNum; i++)
    {
        if (currLogP[i] > prob)
        {
            prob = currLogP[i];
            finalState = i;
        }
    }

    // Decode
    state.push_back(finalState);
    for ( t = size - 2; t >=0; t--)
    {
        int stateIndex = path[t+1][state.back()];
        state.push_back(stateIndex);
    }

    // Reverse the state list
    reverse(state.begin(), state.end());

    // Clean up
    delete[] lastLogP;
    delete[] currLogP;
    for ( i = 0; i < size; i++)
    {
        delete[] path[i];
    }
    delete[] path;

    prob = exp(prob / size);
    return prob;
}

训练问题:

init:把所有样本的每个序列的特征值平均分给每个状态,然后用混合高斯模型表征每个状态。

train:先用decode解码,得到该序列一条概率最大的路径,对路径上所有出现的状态转移进行累积,最后两个状态之间的转移数除以该状态转移到其它所有状态移总数,得到的比值即为状态转移概率和初始状态概率。迭代直到误差小于0.001

/*    SampleFile: <size><dim><seq_size><seq_data>...<seq_size>...*/
void CHMM::Init(const char* sampleFileName)
{
    //--- Debug ---//
    //DumpSampleFile(sampleFileName);

    // Check the sample file
    ifstream sampleFile(sampleFileName, ios_base::binary);
    assert(sampleFile);

    int i,j;
    int size = 0;
    int dim = 0;
    sampleFile.read((char*)&size, sizeof(int));  //读样本数
    sampleFile.read((char*)&dim, sizeof(int));   //读取特征维数
    assert(size >= 3);
    assert(dim == m_stateModel[0]->GetDimNum());

    //这里为从左到右型,第一个状态的初始概率为0.5, 其他状态的初始概率之和为0.5,
    //每个状态到自身的转移概率为0.5, 到下一个状态的转移概率为0.5.
    //此处的初始化主要是对混合高斯模型进行初始化
    for ( i = 0; i < m_stateNum; i++)
    {
        // The initial probabilities
        if(i == 0)
            m_stateInit[i] = 0.5;
        else
            m_stateInit[i] = 0.5 / float(m_stateNum-1);

        // The transition probabilities
        for ( j = 0; j <= m_stateNum; j++)
        {
            if((i == j)||( j == i+1))
                m_stateTran[i][j] = 0.5;
        }
    }

    vector<double*> *gaussseq;
    gaussseq= new vector<double*>[m_stateNum];

    for ( i = 0; i < size; i++)//处理每个样本产生的特征序列
    {
        int seq_size = 0;
        sampleFile.read((char*)&seq_size, sizeof(int));  //序列的长度

        double r = float(seq_size)/float(m_stateNum); //每个状态有r个dim维的特征向量
        for ( j = 0; j < seq_size; j++)
        {
            double* x = new double[dim];
            sampleFile.read((char*)x, sizeof(double) * dim);
            //把特征序列平均分配给每个状态
            gaussseq[int(j/r)].push_back(x);
        }
    }

    char** stateFileName = new char*[m_stateNum];
    ofstream* stateFile = new ofstream[m_stateNum];
    int* stateDataSize = new int[m_stateNum];

    for ( i = 0; i < m_stateNum; i++)
    {
        stateFileName[i] = new char[20];
        ostrstream str(stateFileName[i], 20);
        str << "chmm_s" << i << ".tmp" << '';
    }
    //将每个状态的特征序列保存到文件中,并初始化GMM
    for ( i = 0; i < m_stateNum; i++)
    {
        stateFile[i].open(stateFileName[i], ios_base::binary);
        stateDataSize[i] = gaussseq[i].size();
        stateFile[i].write((char*)&stateDataSize[i], sizeof(int));
        stateFile[i].write((char*)&dim, sizeof(int));
        double* x = new double[dim];
        for( j = 0; j < stateDataSize[i]; j++)
        {
            x = (double*)gaussseq[i].at(j);
            stateFile[i].write((char*)x, sizeof(double) * dim);
        }
        delete x;
        stateFile[i].close();
        //使用Kmeans算法初始化状态的每个GMM
        m_stateModel[i]->Train_Lee(stateFileName[i],i);
        gaussseq[i].clear();
    }

    for ( i = 0; i < m_stateNum; i++)
        delete[] stateFileName[i];

    delete[] stateFileName;
    delete[] stateFile;
    delete[] stateDataSize;
    delete[] gaussseq;
}

/*    SampleFile: <size><dim><seq_size><seq_data>...<seq_size>...*/
void CHMM::Train(const char* sampleFileName)
{
    Init(sampleFileName);

    //--- Debug ---//
    DumpSampleFile(sampleFileName);

    // Check the sample file
    ifstream sampleFile(sampleFileName, ios_base::binary);
    assert(sampleFile);
    int i,j;

    int size = 0;
    int dim = 0;
    sampleFile.read((char*)&size, sizeof(int));
    sampleFile.read((char*)&dim, sizeof(int));
    assert(size >= 3);
    assert(dim == m_stateModel[0]->GetDimNum());

    // Buffer for new model
    int* stateInitNum = new int[m_stateNum];
    int** stateTranNum = new int*[m_stateNum];
    char** stateFileName = new char*[m_stateNum];
    ofstream* stateFile = new ofstream[m_stateNum];
    int* stateDataSize = new int[m_stateNum];

    for ( i = 0; i < m_stateNum; i++)
    {
        stateTranNum[i] = new int[m_stateNum + 1];
        stateFileName[i] = new char[20];
        ostrstream str(stateFileName[i], 20);
        str << "chmm_s" << i << ".tmp" << '';
    }

    bool loop = true;
    double currL = 0;
    double lastL = 0;
    int iterNum = 0; //迭代次数
    int unchanged = 0;
    vector<int> state;
    vector<double*> seq;

    while (loop)
    {
        lastL = currL;
        currL = 0;

        // Clear buffer and open temp data files
        for ( i = 0; i < m_stateNum; i++)
        {
            stateDataSize[i] = 0;
            stateFile[i].open(stateFileName[i], ios_base::binary);
            stateFile[i].write((char*)&stateDataSize[i], sizeof(int));
            stateFile[i].write((char*)&dim, sizeof(int));
            memset(stateTranNum[i], 0, sizeof(int) * (m_stateNum + 1));
        }
        memset(stateInitNum, 0, sizeof(int) * m_stateNum);

        // Predict: obtain the best path
        sampleFile.seekg(sizeof(int) * 2, ios_base::beg);
        for ( i = 0; i < size; i++)
        {
            int seq_size = 0;
            sampleFile.read((char*)&seq_size, sizeof(int));

            for ( j = 0; j < seq_size; j++)
            {
                double* x = new double[dim];
                sampleFile.read((char*)x, sizeof(double) * dim);
                seq.push_back(x);
            }

            currL += LogProb(Decode(seq, state)); //Viterbi解码

            stateInitNum[state[0]]++;
            for ( j = 0; j < seq_size; j++)
            {
                stateFile[state[j]].write((char*)seq[j], sizeof(double) * dim);
                stateDataSize[state[j]]++;
                if (j > 0)
                {
                    stateTranNum[state[j-1]][state[j]]++;
                }
            }
            stateTranNum[state[j-1]][m_stateNum]++; // Final state

            for ( j = 0; j < seq_size; j++)
            {
                delete[] seq[j];
            }
            state.clear();
            seq.clear();
        }
        currL /= size;

        // Close temp data files
        for ( i = 0; i < m_stateNum; i++)
        {
            stateFile[i].seekp(0, ios_base::beg);
            stateFile[i].write((char*)&stateDataSize[i], sizeof(int));
            stateFile[i].close();
        }

        // Reestimate: stateModel, stateInit, stateTran
        int count = 0;
        for ( j = 0; j < m_stateNum; j++)
        {
            if (stateDataSize[j] > m_stateModel[j]->GetMixNum() * 2)
            {
                m_stateModel[j]->DumpSampleFile(stateFileName[j]);
                m_stateModel[j]->Train_Lee(stateFileName[j],j);
            }
            count += stateInitNum[j];
        }
        for ( j = 0; j < m_stateNum; j++)
        {
            m_stateInit[j] = 1.0 * stateInitNum[j] / count;
        }
        for ( i = 0; i < m_stateNum; i++)
        {
            count = 0;
            for ( j = 0; j < m_stateNum + 1; j++)
            {
                count += stateTranNum[i][j];
            }
            if (count > 0)
            {
                for ( j = 0; j < m_stateNum + 1; j++)
                {
                    m_stateTran[i][j] = 1.0 * stateTranNum[i][j] / count;
                }
            }
        }
        // Terminal conditions
        iterNum++;
        unchanged = (currL - lastL < m_endError * fabs(lastL)) ? (unchanged + 1) : 0;
        if (iterNum >= m_maxIterNum || unchanged >= 3)
        {
            loop = false;
            ofstream fout("model.txt", ofstream::app);
            fout<<endl;
            for ( j = 0; j < m_stateNum; j++)
            {
                fout<<m_stateInit[j]<<" ";
            }
            fout<<endl;
            for ( i = 0; i < m_stateNum; i++)
            {
                for ( j = 0; j < m_stateNum + 1; j++)
                {
                    fout<<m_stateTran[i][j]<<" ";
                }
                fout<<endl;
            }
        }
        //DEBUG
        //cout << "Iter: " << iterNum << ", Average Log-Probability: " << currL << endl;
    }

    for ( i = 0; i < m_stateNum; i++)
    {
        delete[] stateTranNum[i];
        delete[] stateFileName[i];
    }
    delete[] stateTranNum;
    delete[] stateFileName;
    delete[] stateFile;
    delete[] stateInitNum;
    delete[] stateDataSize;
}
原文地址:https://www.cnblogs.com/jerrice/p/4354979.html