engine.go

package engine

import (
    "fmt"
    "github.com/huichen/murmur"
    "github.com/huichen/sego"
    "github.com/huichen/wukong/core"
    "github.com/huichen/wukong/storage"
    "github.com/huichen/wukong/types"
    "github.com/huichen/wukong/utils"
    "log"
    "os"
    "runtime"
    "sort"
    "strconv"
    "sync/atomic"
    "time"
)

const (
    NumNanosecondsInAMillisecond = 1000000
    PersistentStorageFilePrefix  = "wukong"
)

type Engine struct {
    // Counters for statistics such as how many documents have been indexed
    numDocumentsIndexed      uint64
    numDocumentsRemoved      uint64
    numDocumentsForceUpdated uint64
    numIndexingRequests      uint64
    numRemovingRequests      uint64
    numForceUpdatingRequests uint64
    numTokenIndexAdded       uint64
    numDocumentsStored       uint64

    // Initialization options recorded at Init time
    initOptions types.EngineInitOptions
    initialized bool

    indexers   []core.Indexer
    rankers    []core.Ranker
    segmenter  sego.Segmenter
    stopTokens StopTokens
    dbs        []storage.Storage

    // Channels used for building the index
    segmenterChannel         chan segmenterRequest
    indexerAddDocChannels    []chan indexerAddDocumentRequest
    indexerRemoveDocChannels []chan indexerRemoveDocRequest
    rankerAddDocChannels     []chan rankerAddDocRequest

    // Channels used for searching and ranking
    indexerLookupChannels   []chan indexerLookupRequest
    rankerRankChannels      []chan rankerRankRequest
    rankerRemoveDocChannels []chan rankerRemoveDocRequest

    // Channels used for persistent storage
    persistentStorageIndexDocumentChannels []chan persistentStorageIndexDocumentRequest
    persistentStorageInitChannel           chan bool
}

func (engine *Engine) Init(options types.EngineInitOptions) {
    // Let the runtime use as many OS threads as there are CPUs
    runtime.GOMAXPROCS(runtime.NumCPU())

    // Record the initialization options; the engine must only be initialized once
    if engine.initialized {
        log.Fatal("请勿重复初始化引擎")
    }
    options.Init()
    engine.initOptions = options
    engine.initialized = true

    if !options.NotUsingSegmenter {
        // Load the segmenter dictionaries
        engine.segmenter.LoadDictionary(options.SegmenterDictionaries)

        // Initialize the stop-token list
        engine.stopTokens.Init(options.StopTokenFile)
    }

    // Initialize the indexers and rankers, one per shard
    for shard := 0; shard < options.NumShards; shard++ {
        engine.indexers = append(engine.indexers, core.Indexer{})
        engine.indexers[shard].Init(*options.IndexerInitOptions)

        engine.rankers = append(engine.rankers, core.Ranker{})
        engine.rankers[shard].Init()
    }

    // Initialize the segmenter channel
    engine.segmenterChannel = make(
        chan segmenterRequest, options.NumSegmenterThreads)

    // Initialize the indexer channels
    engine.indexerAddDocChannels = make(
        []chan indexerAddDocumentRequest, options.NumShards)
    engine.indexerRemoveDocChannels = make(
        []chan indexerRemoveDocRequest, options.NumShards)
    engine.indexerLookupChannels = make(
        []chan indexerLookupRequest, options.NumShards)
    for shard := 0; shard < options.NumShards; shard++ {
        engine.indexerAddDocChannels[shard] = make(
            chan indexerAddDocumentRequest,
            options.IndexerBufferLength)
        engine.indexerRemoveDocChannels[shard] = make(
            chan indexerRemoveDocRequest,
            options.IndexerBufferLength)
        engine.indexerLookupChannels[shard] = make(
            chan indexerLookupRequest,
            options.IndexerBufferLength)
    }

    // Initialize the ranker channels
    engine.rankerAddDocChannels = make(
        []chan rankerAddDocRequest, options.NumShards)
    engine.rankerRankChannels = make(
        []chan rankerRankRequest, options.NumShards)
    engine.rankerRemoveDocChannels = make(
        []chan rankerRemoveDocRequest, options.NumShards)
    for shard := 0; shard < options.NumShards; shard++ {
        engine.rankerAddDocChannels[shard] = make(
            chan rankerAddDocRequest,
            options.RankerBufferLength)
        engine.rankerRankChannels[shard] = make(
            chan rankerRankRequest,
            options.RankerBufferLength)
        engine.rankerRemoveDocChannels[shard] = make(
            chan rankerRemoveDocRequest,
            options.RankerBufferLength)
    }

    // Initialize the persistent-storage channels
    if engine.initOptions.UsePersistentStorage {
        engine.persistentStorageIndexDocumentChannels =
            make([]chan persistentStorageIndexDocumentRequest,
                engine.initOptions.PersistentStorageShards)
        for shard := 0; shard < engine.initOptions.PersistentStorageShards; shard++ {
            engine.persistentStorageIndexDocumentChannels[shard] = make(
                chan persistentStorageIndexDocumentRequest)
        }
        engine.persistentStorageInitChannel = make(
            chan bool, engine.initOptions.PersistentStorageShards)
    }

    // Start the segmenter worker goroutines
    for iThread := 0; iThread < options.NumSegmenterThreads; iThread++ {
        go engine.segmenterWorker()
    }

    // Start the indexer and ranker workers
    for shard := 0; shard < options.NumShards; shard++ {
        go engine.indexerAddDocumentWorker(shard)
        go engine.indexerRemoveDocWorker(shard)
        go engine.rankerAddDocWorker(shard)
        go engine.rankerRemoveDocWorker(shard)

        for i := 0; i < options.NumIndexerThreadsPerShard; i++ {
            go engine.indexerLookupWorker(shard)
        }
        for i := 0; i < options.NumRankerThreadsPerShard; i++ {
            go engine.rankerRankWorker(shard)
        }
    }

    // Set up persistent storage and start its worker goroutines
    if engine.initOptions.UsePersistentStorage {
        err := os.MkdirAll(engine.initOptions.PersistentStorageFolder, 0700)
        if err != nil {
            log.Fatal("无法创建目录", engine.initOptions.PersistentStorageFolder)
        }

        // Open or create the databases
        engine.dbs = make([]storage.Storage, engine.initOptions.PersistentStorageShards)
        for shard := 0; shard < engine.initOptions.PersistentStorageShards; shard++ {
            dbPath := engine.initOptions.PersistentStorageFolder + "/" + PersistentStorageFilePrefix + "." + strconv.Itoa(shard)
            db, err := storage.OpenStorage(dbPath)
            if db == nil || err != nil {
                log.Fatal("无法打开数据库", dbPath, ": ", err)
            }
            engine.dbs[shard] = db
        }

        // Restore documents from the databases
        for shard := 0; shard < engine.initOptions.PersistentStorageShards; shard++ {
            go engine.persistentStorageInitWorker(shard)
        }

        // Wait until the restore finishes
        for shard := 0; shard < engine.initOptions.PersistentStorageShards; shard++ {
            <-engine.persistentStorageInitChannel
        }
        // Busy-wait until every restored document has been re-indexed
        for {
            runtime.Gosched()
            if engine.numIndexingRequests == engine.numDocumentsIndexed {
                break
            }
        }

        // Close and reopen the databases
        for shard := 0; shard < engine.initOptions.PersistentStorageShards; shard++ {
            engine.dbs[shard].Close()
            dbPath := engine.initOptions.PersistentStorageFolder + "/" + PersistentStorageFilePrefix + "." + strconv.Itoa(shard)
            db, err := storage.OpenStorage(dbPath)
            if db == nil || err != nil {
                log.Fatal("无法打开数据库", dbPath, ": ", err)
            }
            engine.dbs[shard] = db
        }

        for shard := 0; shard < engine.initOptions.PersistentStorageShards; shard++ {
            go engine.persistentStorageIndexDocumentWorker(shard)
        }
    }

    atomic.AddUint64(&engine.numDocumentsStored, engine.numIndexingRequests)
}

// Add a document to the index.
//
// Arguments:
//  docId       the document identifier; it must be unique. docId == 0 denotes an
//              invalid document (used to force an index flush); valid documents
//              lie in [1, +oo)
//  data        see the comments on DocumentIndexData
//  forceUpdate whether to force a cache flush: if true the document is added to
//              the index as soon as possible, otherwise it waits until the cache
//              is full and is added in one batch
//
// Notes:
//      1. This function is thread-safe; call it concurrently to speed up indexing.
//      2. The call is asynchronous, so when it returns the document may not have
//         been added to the index yet and an immediate Search may not find it.
//         To force a flush, call FlushIndex.
func (engine *Engine) IndexDocument(docId uint64, data types.DocumentIndexData, forceUpdate bool) {
    engine.internalIndexDocument(docId, data, forceUpdate)

    if engine.initOptions.UsePersistentStorage && docId != 0 {
        // Hand the document to the persistent-storage worker for its shard
        hash := murmur.Murmur3([]byte(fmt.Sprintf("%d", docId))) % uint32(engine.initOptions.PersistentStorageShards)
        engine.persistentStorageIndexDocumentChannels[hash] <- persistentStorageIndexDocumentRequest{docId: docId, data: data}
    }
}

func (engine *Engine) internalIndexDocument(
    docId uint64, data types.DocumentIndexData, forceUpdate bool) {
    if !engine.initialized {
        log.Fatal("必须先初始化引擎")
    }

    if docId != 0 {
        atomic.AddUint64(&engine.numIndexingRequests, 1)
    }
    if forceUpdate {
        atomic.AddUint64(&engine.numForceUpdatingRequests, 1)
    }
    hash := murmur.Murmur3([]byte(fmt.Sprintf("%d%s", docId, data.Content)))
    engine.segmenterChannel <- segmenterRequest{
        docId: docId, hash: hash, data: data, forceUpdate: forceUpdate}
}

// Remove a document from the index.
//
// Arguments:
//  docId       the document identifier; it must be unique. docId == 0 denotes an
//              invalid document (used to force an index flush); valid documents
//              lie in [1, +oo)
//  forceUpdate whether to force a cache flush: if true the document is removed from
//              the index as soon as possible, otherwise it waits until the cache
//              is full and is removed in one batch
//
// Notes:
//      1. This function is thread-safe; call it concurrently to speed up indexing.
//      2. The call is asynchronous, so when it returns the document may not have
//         been removed from the index yet and an immediate Search may still
//         return it. To force a flush, call FlushIndex.
func (engine *Engine) RemoveDocument(docId uint64, forceUpdate bool) {
    if !engine.initialized {
        log.Fatal("必须先初始化引擎")
    }

    if docId != 0 {
        atomic.AddUint64(&engine.numRemovingRequests, 1)
    }
    if forceUpdate {
        atomic.AddUint64(&engine.numForceUpdatingRequests, 1)
    }
    for shard := 0; shard < engine.initOptions.NumShards; shard++ {
        engine.indexerRemoveDocChannels[shard] <- indexerRemoveDocRequest{docId: docId, forceUpdate: forceUpdate}
        if docId == 0 {
            continue
        }
        engine.rankerRemoveDocChannels[shard] <- rankerRemoveDocRequest{docId: docId}
    }

    if engine.initOptions.UsePersistentStorage && docId != 0 {
        // Remove the document from the database as well
        hash := murmur.Murmur3([]byte(fmt.Sprintf("%d", docId))) % uint32(engine.initOptions.PersistentStorageShards)
        go engine.persistentStorageRemoveDocumentWorker(docId, hash)
    }
}

// Find documents matching the search request. This function is thread-safe.
func (engine *Engine) Search(request types.SearchRequest) (output types.SearchResponse) {
    if !engine.initialized {
        log.Fatal("必须先初始化引擎")
    }

    var rankOptions types.RankOptions
    if request.RankOptions == nil {
        rankOptions = *engine.initOptions.DefaultRankOptions
    } else {
        rankOptions = *request.RankOptions
    }
    if rankOptions.ScoringCriteria == nil {
        rankOptions.ScoringCriteria = engine.initOptions.DefaultRankOptions.ScoringCriteria
    }

    // Collect the query tokens
    tokens := []string{}
    if request.Text != "" {
        querySegments := engine.segmenter.Segment([]byte(request.Text))
        for _, s := range querySegments {
            token := s.Token().Text()
            if !engine.stopTokens.IsStopToken(token) {
                tokens = append(tokens, token)
            }
        }
    } else {
        tokens = append(tokens, request.Tokens...)
    }

    // Create the channel on which the rankers return their results
    rankerReturnChannel := make(
        chan rankerReturnRequest, engine.initOptions.NumShards)

    // Build the lookup request
    lookupRequest := indexerLookupRequest{
        countDocsOnly:       request.CountDocsOnly,
        tokens:              tokens,
        labels:              request.Labels,
        docIds:              request.DocIds,
        options:             rankOptions,
        rankerReturnChannel: rankerReturnChannel,
        orderless:           request.Orderless,
    }

    // Send the lookup request to every indexer shard
    for shard := 0; shard < engine.initOptions.NumShards; shard++ {
        engine.indexerLookupChannels[shard] <- lookupRequest
    }

    // Read the rankers' output from the channel
    numDocs := 0
    rankOutput := types.ScoredDocuments{}
    timeout := request.Timeout
    isTimeout := false
    if timeout <= 0 {
        // No timeout
        for shard := 0; shard < engine.initOptions.NumShards; shard++ {
            rankerOutput := <-rankerReturnChannel
            if !request.CountDocsOnly {
                for _, doc := range rankerOutput.docs {
                    rankOutput = append(rankOutput, doc)
                }
            }
            numDocs += rankerOutput.numDocs
        }
    } else {
        // With a timeout
        deadline := time.Now().Add(time.Nanosecond * time.Duration(NumNanosecondsInAMillisecond*request.Timeout))
        for shard := 0; shard < engine.initOptions.NumShards; shard++ {
            select {
            case rankerOutput := <-rankerReturnChannel:
                if !request.CountDocsOnly {
                    for _, doc := range rankerOutput.docs {
                        rankOutput = append(rankOutput, doc)
                    }
                }
                numDocs += rankerOutput.numDocs
            case <-time.After(deadline.Sub(time.Now())):
                // Note that break only exits this select, not the loop; the
                // remaining iterations either read results that are already
                // buffered or time out immediately, since the deadline has passed.
                isTimeout = true
                break
            }
        }
    }

    // Sort the merged results
    if !request.CountDocsOnly && !request.Orderless {
        if rankOptions.ReverseOrder {
            sort.Sort(sort.Reverse(rankOutput))
        } else {
            sort.Sort(rankOutput)
        }
    }

    // Prepare the output
    output.Tokens = tokens
    // output.Docs is only populated when CountDocsOnly is false
    if !request.CountDocsOnly {
        if request.Orderless {
            // In orderless mode there is no need to truncate by offset
            output.Docs = rankOutput
        } else {
            start := utils.MinInt(rankOptions.OutputOffset, len(rankOutput))
            end := len(rankOutput)
            if rankOptions.MaxOutputs != 0 {
                end = utils.MinInt(start+rankOptions.MaxOutputs, len(rankOutput))
            }
            output.Docs = rankOutput[start:end]
        }
    }
    output.NumDocs = numDocs
    output.Timeout = isTimeout
    return
}

// Block until all pending index additions and removals have been processed
func (engine *Engine) FlushIndex() {
    for {
        runtime.Gosched()
        if engine.numIndexingRequests == engine.numDocumentsIndexed &&
            engine.numRemovingRequests*uint64(engine.initOptions.NumShards) == engine.numDocumentsRemoved &&
            (!engine.initOptions.UsePersistentStorage || engine.numIndexingRequests == engine.numDocumentsStored) {
            // Make sure every queued channel request has been processed
            break
        }
    }
    // Force an update so that this is guaranteed to be the last request
    engine.IndexDocument(0, types.DocumentIndexData{}, true)
    for {
        runtime.Gosched()
        if engine.numForceUpdatingRequests*uint64(engine.initOptions.NumShards) == engine.numDocumentsForceUpdated {
            return
        }
    }
}

// Shut down the engine
func (engine *Engine) Close() {
    engine.FlushIndex()
    if engine.initOptions.UsePersistentStorage {
        for _, db := range engine.dbs {
            db.Close()
        }
    }
}

// Map a text hash to the shard it is assigned to
func (engine *Engine) getShard(hash uint32) int {
    return int(hash % uint32(engine.initOptions.NumShards))
}
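
For reference, below is a minimal usage sketch of the engine defined above. It is not part of engine.go: the dictionary path and the sample documents are placeholders, and it assumes that options.Init() fills in workable defaults (shard counts, ranking options, and so on) for every field left unset.

package main

import (
    "fmt"
    "github.com/huichen/wukong/engine"
    "github.com/huichen/wukong/types"
)

func main() {
    var searcher engine.Engine

    // Initialize the engine once; SegmenterDictionaries must point at a real
    // sego dictionary file (the path below is only a placeholder).
    searcher.Init(types.EngineInitOptions{
        SegmenterDictionaries: "data/dictionary.txt",
    })
    defer searcher.Close()

    // IndexDocument is asynchronous and thread-safe; docId 0 is reserved for
    // forcing an index flush, so real documents start at 1.
    searcher.IndexDocument(1, types.DocumentIndexData{Content: "百度是一家搜索引擎公司"}, false)
    searcher.IndexDocument(2, types.DocumentIndexData{Content: "谷歌也是一家搜索引擎公司"}, false)

    // Block until both documents are actually searchable.
    searcher.FlushIndex()

    // Search is synchronous; the response carries the segmented query tokens
    // and the total number of matched documents.
    response := searcher.Search(types.SearchRequest{Text: "搜索引擎"})
    fmt.Println("tokens:", response.Tokens)
    fmt.Println("documents matched:", response.NumDocs)

    // Removal is asynchronous as well; forceUpdate=true flushes it promptly.
    searcher.RemoveDocument(2, true)
}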

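Both the persistent-storage sharding in IndexDocument/RemoveDocument and getShard above reduce a Murmur3 hash modulo the number of shards. The standalone sketch below only illustrates that arithmetic; pickShard and the shard count of 8 are made up for the example.

package main

import (
    "fmt"
    "github.com/huichen/murmur"
)

// pickShard hashes the docId's decimal string with Murmur3 and reduces it
// modulo the number of shards, mirroring the engine's shard selection.
func pickShard(docId uint64, numShards int) int {
    hash := murmur.Murmur3([]byte(fmt.Sprintf("%d", docId)))
    return int(hash % uint32(numShards))
}

func main() {
    for docId := uint64(1); docId <= 5; docId++ {
        fmt.Printf("docId %d -> shard %d of 8\n", docId, pickShard(docId, 8))
    }
}
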
Original article: https://www.cnblogs.com/zhangboyu/p/7461669.html