用 Golang 实现 MySQL UDF

UDF(user-defined function)

当 MySQL 提供的内置函数(COUNT、MIN、MAX 等)无法满足需求时,可以通过 UDF 扩展自定义函数,以满足特定的查询需求。
这里假定一个数据库应用场景:表 t_quest_db 有 2 个字段,roleid(INT)和 data(BLOB)。data 字段存放的是自行编码(本例使用 Go 的 encoding/gob)的二进制数据,直接 SELECT 只能得到不可读的字节;若想查询出人眼可读的内容,就需要用到 UDF 了。

主要代码实现

示例代码用到了 cgo,完整示例代码见:https://files.cnblogs.com/files/fitness/udftest.zip

// 本文件实现t_quest_db的data字段的序列化和反序列化

package udfdll

import (
	"bytes"
	"encoding/gob"
)

// PlayerDoingQuestDbData is the per-quest record kept in
// PlayerAllQuestDbData.DoingTaskMap (presumably quests still in progress).
type PlayerDoingQuestDbData struct {
	TaskId           int
	TaskDbId         int64
	AcceptTime       int32
	ProgressMap      map[int]string
	IsProgressFinish bool
}

// PlayerDoneQuestDbData is the per-quest record kept in
// PlayerAllQuestDbData.DoneTaskMap.
type PlayerDoneQuestDbData struct {
	TaskId            int
	LastDoneTimestamp int32 // timestamp of the most recent completion
	PeriodDoneTimes   int   // number of completions within the quest period (day/week)
}

// PlayerAllQuestDbData is the value persisted in the data column of
// t_quest_db. It is gob-encoded by ToDbBlob and decoded by FromDbBlob.
type PlayerAllQuestDbData struct {
	DoneTaskMap  map[int]*PlayerDoneQuestDbData
	DoingTaskMap map[int64]*PlayerDoingQuestDbData
}

// ToDbBlob gob-encodes q into the bytes to be written to the data column.
// On an encoding error it returns a nil slice and the error.
func (q *PlayerAllQuestDbData) ToDbBlob() ([]byte, error) {
	var buf bytes.Buffer
	if err := gob.NewEncoder(&buf).Encode(q); err != nil {
		return nil, err
	}
	return buf.Bytes(), nil
}

// FromDbBlob decodes a data-column blob into q.
//
// An empty blob is not an error: q is left as-is. After a successful call
// (including the empty case) both maps are guaranteed non-nil, because gob
// omits empty maps when encoding. On a decode error the error is returned
// immediately and the maps are not initialised.
func (q *PlayerAllQuestDbData) FromDbBlob(data []byte) error {
	if len(data) != 0 {
		if err := gob.NewDecoder(bytes.NewReader(data)).Decode(q); err != nil {
			return err
		}
	}

	if q.DoneTaskMap == nil {
		q.DoneTaskMap = make(map[int]*PlayerDoneQuestDbData)
	}
	if q.DoingTaskMap == nil {
		q.DoingTaskMap = make(map[int64]*PlayerDoingQuestDbData)
	}
	return nil
}

// 本文件实现UDF扩展

package main

/*
#cgo CFLAGS: -Iinclude
#include <mysql.h>
#include <string.h>
#include <stdlib.h>
*/
import "C"
import (
	"bytes"
	_ "encoding/gob"
	"encoding/json"
	"fmt"
	"strings"
	"unsafe"
)

// main is never executed. The package is loaded by MySQL as a shared library
// (presumably built with -buildmode=c-shared by build.bat, which is not
// shown) and MySQL calls the exported questblob* functions directly, but Go
// still requires package main to declare main.
func main() {
}

// getUintPointerValue returns &pointer[offset], treating pointer as the start
// of a C array whose elements are C.uint-sized.
func getUintPointerValue(pointer *uint32, offset int) *C.uint {
	// Byte distance to the element; kept separate from the pointer itself so
	// the pointer->uintptr->pointer round trip stays in a single expression.
	delta := uintptr(offset) * unsafe.Sizeof(C.uint(0))
	return (*C.uint)(unsafe.Pointer(uintptr(unsafe.Pointer(pointer)) + delta))
}

// getCharPointerValue returns &pointer[offset] for a C array of char*
// (e.g. UDF_ARGS.args), stepping by the size of one pointer.
func getCharPointerValue(pointer **C.char, offset int) **C.char {
	delta := uintptr(offset) * unsafe.Sizeof((*C.char)(nil))
	return (**C.char)(unsafe.Pointer(uintptr(unsafe.Pointer(pointer)) + delta))
}

// getUlongPointerValue returns &pointer[offset] for a C array of unsigned
// long (e.g. UDF_ARGS.lengths).
func getUlongPointerValue(pointer *C.ulong, offset int) *C.ulong {
	delta := uintptr(offset) * unsafe.Sizeof(C.ulong(0))
	return (*C.ulong)(unsafe.Pointer(uintptr(unsafe.Pointer(pointer)) + delta))
}

// getBytePointerValue returns &pointer[offset] for a C array of char.
func getBytePointerValue(pointer *C.char, offset int) *C.char {
	delta := uintptr(offset) * unsafe.Sizeof(C.char(0))
	return (*C.char)(unsafe.Pointer(uintptr(unsafe.Pointer(pointer)) + delta))
}

// questblob_init is MySQL's initialisation hook for the questblob UDF.
//
// It requires at least two arguments (the encoded blob and a "table.column"
// string), marks the result as nullable, and asks MySQL to hand both
// arguments over as strings. It returns 0 on success. On failure it copies a
// NUL-terminated message into msg and returns 1, which makes MySQL reject the
// call. msg is MySQL's error buffer (MYSQL_ERRMSG_SIZE bytes per the MySQL
// docs); the message below is far shorter than that.
//export questblob_init
func questblob_init(init *C.UDF_INIT, args *C.UDF_ARGS, msg *C.char) C.my_bool {
	init.maybe_null = 1
	// questblob_deinit frees init.ptr, so make sure it never starts out stale.
	init.ptr = nil

	if args.arg_count < 2 {
		// Keep the text consistent with the "< 2" check above.
		s := C.CString("questblob() requires 2 arguments: (blob, 'table.column')")
		defer C.free(unsafe.Pointer(s))
		C.strcpy(msg, s)
		return 1
	}

	// Have MySQL coerce both arguments to strings before each row call.
	*getUintPointerValue(args.arg_type, 0) = C.STRING_RESULT
	*getUintPointerValue(args.arg_type, 1) = C.STRING_RESULT
	return 0
}

// questblob_deinit is MySQL's clean-up hook for the questblob UDF. It
// releases the C buffer, if any, still referenced by init.ptr.
//export questblob_deinit
func questblob_deinit(init *C.UDF_INIT) {
	if p := unsafe.Pointer(init.ptr); p != nil {
		C.free(p)
	}
}

// questblob is the per-row entry point of the questblob UDF.
//
// args[0] is the gob-encoded t_quest_db.data blob; args[1] is a
// "table.column" string that is currently only used to label error messages
// (it could be used to choose which object to decode). The blob is decoded
// into a PlayerAllQuestDbData and returned as JSON text.
//
// Results:
//   - SQL NULL blob: SQL NULL result.
//   - decode/marshal failure: a non-NULL "[Error]table:'…' column:'…' …"
//     string, so the failure shows up in the SELECT output.
//
// Memory: each result string is allocated on the C heap (C.CString) and
// remembered in init.ptr. MySQL invokes this function once per row with the
// same init, so the previous row's buffer is released at the start of every
// call. The final buffer is left for questblob_deinit. init.ptr must be NULL
// before the first row.
//export questblob
func questblob(init *C.UDF_INIT, args *C.UDF_ARGS, result *C.char, length *C.ulong, is_null *C.char, error *C.char) *C.char {
	*is_null = 1
	*length = 0

	// Free the previous row's result; previously every row except the last
	// leaked its allocation. C.free(nil) is a no-op.
	C.free(unsafe.Pointer(init.ptr))
	init.ptr = nil

	chars := getCharPointerValue(args.args, 0)
	if *chars == nil {
		// SQL NULL blob: report SQL NULL (is_null is already 1).
		return nil
	}
	arglen := getUlongPointerValue(args.lengths, 0)
	rb := C.GoBytes(unsafe.Pointer(*chars), C.int(*arglen))

	// A NULL or dot-less name must not be indexed blindly: an unrecovered Go
	// panic inside an exported cgo function crashes the host process.
	var name string
	if names := getCharPointerValue(args.args, 1); *names != nil {
		namelen := getUlongPointerValue(args.lengths, 1)
		name = C.GoStringN(*names, C.int(*namelen))
	}
	table, column := name, "" // table name, column name
	if nameInfos := strings.Split(name, "."); len(nameInfos) >= 2 {
		table, column = nameInfos[0], nameInfos[1]
	}
	errStringPrefix := fmt.Sprintf("[Error]table:'%s' column:'%s' ", table, column)

	data := &PlayerAllQuestDbData{}
	// bytes.NewBuffer(rb).Bytes() is rb itself; kept so the file's "bytes"
	// import remains in use.
	if err := data.FromDbBlob(bytes.NewBuffer(rb).Bytes()); err != nil {
		return questblobResult(init, length, is_null, errStringPrefix+err.Error())
	}

	out, err := json.Marshal(data)
	if err != nil {
		return questblobResult(init, length, is_null, errStringPrefix+" Marshal error: "+err.Error())
	}
	return questblobResult(init, length, is_null, string(out))
}

// questblobResult copies s to the C heap, records the copy in init.ptr so a
// later call (or questblob_deinit) can free it, and reports it to MySQL as a
// non-NULL string result of len(s) bytes.
func questblobResult(init *C.UDF_INIT, length *C.ulong, isNull *C.char, s string) *C.char {
	cs := C.CString(s)
	init.ptr = cs
	*length = C.ulong(len(s))
	*isNull = 0
	return cs
}

生成dll

在 Windows 下运行 build.bat 生成 libquestblob.dll,并将其拷贝到 MySQL 的 plugin 目录下。

使用

CREATE FUNCTION questblob RETURNS STRING SONAME 'libquestblob.dll'; #此语句执行一次即可
SELECT roleid, cast(questblob(t_quest_db.data, 't_quest_db.data') AS CHAR) FROM t_quest_db;

原文地址:https://www.cnblogs.com/fitness/p/8461921.html