树(一)——线段树

问题

现在有1~30这30个数,数N被抽上的概率正比于1/sqrt(N+1),求满足这个概率分布的随机数发生器。

思路

第一,如何解决这个“概率正比”问题。

第二,如何产生满足条件的随机数。

第三,有更好的方法吗?

 

一、解决“概率正比”问题

在概率论中有一个概念叫作“几何概型”,举个例子,如何求圆的面积?

先画一个正方形记作A,再在A中画内切圆B。现在随机在A上面撒豆子,落在A上的豆子总数为AN,落在B上的为BN。

那么,当豆子总数趋向无穷大时,正方形与圆的面积比率趋向于AN/BN。

也就是说,样本数量足够大时,面积比(几何比)近似于概率

回到刚刚的问题,设数N被抽中的概率为PN,做一条直线,从原点出发,依次放上长为PN的线段LN(N从1到30),现在在L1~L30上随机撒豆子。假设LN上的豆子数为TN,那么TN之近似于PN之比,从而解决第一个问题。

二、产生满足要求的随机数

几何概型其实就是均匀分布要面积(长度、体积)分布的映射,而映射就是函数。

所以,这个随机数发生器的输入是随机数rand(),输出为题目所求,旨在求映射关系。

既然题目中的几何概型已被求出,那么产生相应随机数的基本步骤如下:

  1. set S=P1+P2+…P30,generate rand(0<=rand<S)
  2. for i=1 to 30 do:
  3.     if(rand<P(i)) return i
  4. end for

三、寻求优化的方法

上面的方法很简单,但是存在效率问题。

普通的if-else嵌套分支都可以转化为二叉查找树(if=1=left,else=0=right)。上述方法也如此,将其转化为二叉树之后,发现是一棵斜树

斜树有个缺点,就是作为查找树效率低下,因此存在优化的空间。

对斜树而言:

  1. 考虑最优情况,只判断一次,O(1)
  2. 考虑最坏情况,全部判断一次,O(n)
  3. 考虑平均情况,折半n/2,O(n)

而对经典的二叉平衡树而言(折半查找):

  1. 考虑最优情况,为深度,O(logn)
  2. 考虑最坏情况,为深度,O(logn)
  3. 考虑平均情况,为深度,O(logn)

由于是处理随机数,这里考虑平均情况,可知:二叉平衡树优于斜树

    但是,有一个问题:二叉平衡树是用来查找数的,不是用来查找区间的,所以这里用不了平衡树。

    因此,作为查找区间的一种数据结构——区间树(也称线段树),就应运而生了。

    我这里设计的区间树属于2-3树(键数<=2,值数<=3),结合了广义表的设计思想(结点和数据共用)。

    由于数据源是排序过的,所以可以直接采用分治法构建树。(这是由于分组后的数据仍保持有序)

    注:

    • 打印的树结构设计参考自系统自带tree.exe的结果
    • 键和值的关系为,val0<key0<=val1<key1<=val2
    • 树的生成,以及分治产生的多余结点处理问题(分配不均等)的解决详见源码

    image

    源码

    源码:itvtree.cpp

    #include "stdafx.h"
    #include <stdio.h>
    #include <math.h>
    #include <time.h>
    #include <vector>
    
    /************************************************************************/
    /*     线段树/区间树                                                    */
    /*     Interval Tree                                                    */
    /************************************************************************/
    
    typedef double TreeKeyType;
    
    //实质是2-3平衡树
    struct tree_node
    {
        unsigned char type[4];//只用到type[3],此是为了内存对齐
        TreeKeyType key[2];//2-键
        void* value[3];//3-值
    };
    
    #define TREE_NODE_UNUSED    0
    #define TREE_NODE_VALUE     1
    #define TREE_NODE_POINTER   2
    
    #define TREE_NOT_FOUND        -1
    
    #define TREE_PRINT_BLANK    0
    #define TREE_PRINT_TRUNK    1
    #define TREE_PRINT_BRANCH    2
    #define TREE_PRINT_BRANCHD    3
    
    
    //************************************
    // Method:    递归构造线段树
    // FullName:  tree_init_recursive
    // Access:    public 
    // Returns:   void*
    // Qualifier:
    // Parameter: TreeKeyType * s 数组指针(数据来源)
    // Parameter: int count 当前处理的范围
    // Parameter: int start 当前处理的位置
    // Parameter: unsigned char * type 修改结点类型
    //************************************
    void* tree_init_recursive(TreeKeyType* s, int count, int start, unsigned char* type)
    {
        //此递归调用函数的返回值可以为结点或真值,由type确定
        //树生成采用的是分治方法,当s数组分成3份,各自递归生成子树,再作为某个结点的孩子结点
    
        //注意:实际长度应为count+1
    
        tree_node* node;
    
        if (count == 0)//只有一个数时,只返回值类型,类比快排找中间轴(向中间逼近)
        {
            if (type)
                *type = TREE_NODE_VALUE;
            return (void*)start;//
        }
    
        node = new(tree_node);
        if (type)
            *type = TREE_NODE_POINTER;//当前处理长度大于1,需生成新结点,返回结点类型
    
        switch (count)
        {
        case 1://当前处理的长度为2,type长度为2
            node->type[0] = TREE_NODE_VALUE;
            node->type[1] = TREE_NODE_VALUE;
            node->type[2] = TREE_NODE_UNUSED;
            node->key[0] = s[0];//只需一个键,即大于和小于
            node->value[0] = (void*)(start);
            node->value[1] = (void*)(start + 1);
            break;
        case 2://当前处理的长度为3,type长度为3
            node->type[0] = TREE_NODE_VALUE;
            node->type[1] = TREE_NODE_VALUE;
            node->type[2] = TREE_NODE_VALUE;
            node->key[0] = s[0];
            node->key[1] = s[1];//需要两个键,即大于上界、区间内、小于下界
            node->value[0] = (void*)(start);
            node->value[1] = (void*)(start + 1);
            node->value[2] = (void*)(start + 2);
            break;
        default://处理的长度大于3,采用分治
            int a = count / 3;//三等分
            switch (count % 3)//处理三等分多余的数
            {
            case 0://多一个
                //分组:
                //1)start+[0]->start+[a-1]
                //2)start+[a]->start+[2a]
                //3)start+[2a+1]->start+[3a]
                //长度:
                //1)a-1
                //2)a
                //3)a-1
                //键:
                //1)a-1
                //2)2a
                node->key[0] = s[a - 1];
                node->key[1] = s[a * 2];
                node->value[0] = tree_init_recursive(&s[0], a - 1, start, &node->type[0]);
                node->value[1] = tree_init_recursive(&s[a], a, start + a, &node->type[1]);
                node->value[2] = tree_init_recursive(&s[a * 2 + 1], a - 1, start + a * 2 + 1, &node->type[2]);
                break;
            case 1://多两个
                //分组:
                //1)start+[0]->start+[a-1]
                //2)start+[a]->start+[2a+1]
                //3)start+[2a+2]->start+[3a+1]
                //长度:
                //1)a-1
                //2)a+1
                //3)a-1
                //键:
                //1)a-1
                //2)2a+1
                node->key[0] = s[a - 1];
                node->key[1] = s[a * 2 + 1];
                node->value[0] = tree_init_recursive(&s[0], a - 1, start, &node->type[0]);
                node->value[1] = tree_init_recursive(&s[a], a + 1, start + a, &node->type[1]);
                node->value[2] = tree_init_recursive(&s[a * 2 + 2], a - 1, start + a * 2 + 2, &node->type[2]);
                break;
            case 2://不多
                //分组:
                //1)start+[0]->start+[a]
                //2)start+[a+1]->start+[2a+1]
                //3)start+[2a+2]->start+[3a+2]
                //长度:
                //1)a
                //2)a
                //3)a
                //键:
                //1)a
                //2)2a+1
                node->key[0] = s[a];
                node->key[1] = s[a * 2 + 1];
                node->value[0] = tree_init_recursive(&s[0], a, start, &node->type[0]);
                node->value[1] = tree_init_recursive(&s[a + 1], a, start + a + 1, &node->type[1]);
                node->value[2] = tree_init_recursive(&s[a * 2 + 2], a, start + a * 2 + 2, &node->type[2]);
                break;
            }
        }
    
        return (void*)node;
    }
    
    tree_node* tree_init(TreeKeyType* s, int count, int start)
    {
        //初始化
        //给定线段各点坐标,构建树
        return (tree_node*)tree_init_recursive(s, count, start, NULL);
    }
    
    int tree_find_recursive(tree_node* node, TreeKeyType s)
    {
        //当前结点类型为值就直接返回,否则递归调用
        if (s < node->key[0])//小于左键,双键小于或单键小于,找第一值
        {
            if (node->type[0] == TREE_NODE_VALUE)
                return (int)node->value[0];
            else if (node->type[0] == TREE_NODE_POINTER)
                return (int)tree_find_recursive((tree_node*)node->value[0], s);
            else
                return TREE_NOT_FOUND;
        }
        if (node->type[2] == TREE_NODE_UNUSED || s < node->key[1])//双键区间内部或单键大于,找第二值
        {
            if (node->type[1] == TREE_NODE_VALUE)
                return (int)node->value[1];
            else if (node->type[1] == TREE_NODE_POINTER)
                return (int)tree_find_recursive((tree_node*)node->value[1], s);
            else
                return TREE_NOT_FOUND;
        }
        {//双键大于,找第三值
            if (node->type[2] == TREE_NODE_VALUE)
                return (int)node->value[2];
            else if (node->type[2] == TREE_NODE_POINTER)
                return (int)tree_find_recursive((tree_node*)node->value[2], s);
            else
                return TREE_NOT_FOUND;
        }
    }
    
    int tree_find(tree_node* node, TreeKeyType s)
    {
        if (!node) return TREE_NOT_FOUND;
        return tree_find_recursive(node, s);
    }
    
    int tree_size_recursive(tree_node* node)
    {
        int size = 1;
        if (node->type[0] == TREE_NODE_POINTER)
            size += tree_size_recursive((tree_node*)node->value[0]);
        if (node->type[1] == TREE_NODE_POINTER)
            size += tree_size_recursive((tree_node*)node->value[1]);
        if (node->type[2] == TREE_NODE_POINTER)
            size += tree_size_recursive((tree_node*)node->value[2]);
        return size;
    }
    
    int tree_size(tree_node* node)
    {
        if (!node) return 0;
        return tree_size_recursive(node);
    }
    
    void tree_destroy_recursive(tree_node* node)
    {
        if (node->type[0] == TREE_NODE_POINTER)
            tree_destroy_recursive((tree_node*)node->value[0]);
        if (node->type[1] == TREE_NODE_POINTER)
            tree_destroy_recursive((tree_node*)node->value[1]);
        if (node->type[2] == TREE_NODE_POINTER)
            tree_destroy_recursive((tree_node*)node->value[2]);
        delete (node);
    }
    
    void tree_destroy(tree_node* node)
    {
        if (!node) return;
        tree_destroy_recursive(node);
    }
    
    void tree_print_helper(const std::vector<int>& mark)
    {
        for (auto m : mark)
        {
            switch (m)
            {
            case TREE_PRINT_BLANK:
                printf("        ");
                break;
            case TREE_PRINT_TRUNK:
                printf("");
                break;
            case TREE_PRINT_BRANCH:
                printf("├───");
                break;
            case TREE_PRINT_BRANCHD:
                printf("└───");
                break;
            default:
                break;
            }        
        }    
    }
    
    void tree_print_recursive(tree_node* node, std::vector<int>& mark)
    {
        if (node == NULL) return;
        tree_print_helper(mark);
        int last_branch = node->type[2] != TREE_NODE_UNUSED ? 2 :
            node->type[1] != TREE_NODE_UNUSED ? 1 : 0;
        if (last_branch == 2)
            printf("<%f, %f>
    ", node->key[0], node->key[1]);
        else
            printf("<%f>
    ", node->key[0]);
        int last_mark = *mark.rbegin();
        if (last_mark != TREE_PRINT_BLANK)
        {
            mark.pop_back();
            mark.push_back(last_mark == TREE_PRINT_BRANCHD ? TREE_PRINT_BLANK : TREE_PRINT_TRUNK);
        }
        for (int i = 0; i <= last_branch; i++)
        {
            if (last_branch == i)
            {
                mark.push_back(TREE_PRINT_BRANCHD);
            }
            else
            {
                mark.push_back(TREE_PRINT_BRANCH);
            }
            if (node->type[i] == TREE_NODE_POINTER)
            {
                tree_print_recursive((tree_node*)node->value[i], mark);
            }
            else if (node->type[i] == TREE_NODE_VALUE)
            {
                tree_print_helper(mark);
                printf("<%d>
    ", (int)node->value[i]);
            }
            if (last_branch != i)
            {
                mark.pop_back();
            }
        }
        mark.pop_back();
    }
    
    void tree_print(tree_node* node)
    {
        if (!node) return;
        std::vector<int> mark;
        mark.push_back(TREE_PRINT_BLANK);
        tree_print_recursive(node, mark);
    }
    
    int main(int argc, char* argv[])
    {
        //要求:对第i的页面的访问概率正比于1/sqrt(i+1)
        const int count = 30;
        const int tests = 10;
        TreeKeyType* sum = new TreeKeyType[count];
        sum[0] = 1;
        //sum[0]=1
        //sum[1]=sum[0]+1/sqrt(2)
        //sum[2]=sum[1]+1/sqrt(3)
        //...
        //sum[n-1]=sum[n-2]+1/sqrt(n)
        for (int i = 1; i < count; i++)
        {
            sum[i] = sum[i - 1] + (double)(1 / sqrt(i + 1));
        }
        TreeKeyType MaxRandValue = sum[count - 1] - 0.00001;
        tree_node* search_node = tree_init(sum, count, 0);
        printf("Search node size: %d
    ", tree_size(search_node) * sizeof(search_node));
        printf("========== tree ==========
    ");
        tree_print(search_node);
        printf("========== find ==========
    ");
        srand((unsigned int)time(NULL));
        for (int i = 0; i < tests; i++)
        {
            TreeKeyType rnd = rand() / double(RAND_MAX) * MaxRandValue;
            printf("key: %f, val: %d
    ", rnd, tree_find(search_node, rnd));
        }
        delete[] (sum);
        return 0;
    }
    原文地址:https://www.cnblogs.com/bajdcc/p/4777879.html