一个网页AI聊天,图像识别项目,使用到的技术栈有,Gin框架,GORM,rabbitmq,redis,eino框架,Vue框架等
项目地址:https://github.com/youngyangyang04/GopherAI
项目使用了ONNX模型本地推理
路由层
只有一个API接口
// 为接口路径设置handler函数
func ImageRouter(r *gin.RouterGroup) {
r.POST("/recognize", image.RecognizeImage)
}
控制器层
首先定义了响应的JSON格式结构体,除了通用的部分外只有一个字符串表示识别结果
type (
RecognizeImageResponse struct {
ClassName string `json:"class_name,omitempty"` // AI回答
controller.Response
}
)
API对应的handler函数
// /recognize接口的handler函数
func RecognizeImage(c *gin.Context) {
res := new(RecognizeImageResponse)
// 从HTTP请求中取文件
file, err := c.FormFile("image")
if err != nil {
log.Println("FormFile fail ", err)
c.JSON(http.StatusOK, res.CodeOf(code.CodeInvalidParams))
return
}
// 调用服务层,识别图像
className, err := image.RecognizeImage(file)
if err != nil {
log.Println("RecognizeImage fail ", err)
c.JSON(http.StatusOK, res.CodeOf(code.CodeServerBusy))
return
}
// 返回结果
res.Success()
res.ClassName = className
c.JSON(http.StatusOK, res)
}
服务层
首先会加载模型,然后打开图像,转成二进制buffer,最后交给模型推理
// 调用模型识别图像
func RecognizeImage(file *multipart.FileHeader) (string, error) {
// 模型路径
modelPath := "/home/ayanami/GopherAI/GopherAI-v1/models/mobilenetv2/mobilenetv2-7.onnx"
labelPath := "/home/ayanami/GopherAI/GopherAI-v1/models/imagenet_classes.txt"
inputH, inputW := 224, 224
// 加载模型
recognizer, err := image.NewImageRecognizer(modelPath, labelPath, inputH, inputW)
if err != nil {
log.Println("NewImageRecognizer fail err is : ", err)
return "", err
}
defer recognizer.Close()
// 打开上传的文件
src, err := file.Open()
if err != nil {
log.Println("file open fail err is : ", err)
return "", err
}
defer src.Close()
// 转成二进制buffer
buf, err := io.ReadAll(src)
if err != nil {
log.Println("io.ReadAll fail err is : ", err)
return "", err
}
// 推理
return recognizer.PredictFromBuffer(buf)
}
图像识别模块
首先定义了ImageRecognizer结构体,描述模型推理参数
type ImageRecognizer struct {
session *ort.Session[float32]
inputName string
outputName string
inputH int
inputW int
labels []string
inputTensor *ort.Tensor[float32]
outputTensor *ort.Tensor[float32]
}
创建ImageRecognizer,初始化一个可推理的神经网络
// NewImageRecognizer 创建识别器(自动使用默认 input/output 名称)
func NewImageRecognizer(modelPath, labelPath string, inputH, inputW int) (*ImageRecognizer, error) {
if inputH <= 0 || inputW <= 0 {
inputH, inputW = 224, 224
}
// 初始化 ONNX 环境(全局一次)
initOnce.Do(func() {
initErr = ort.InitializeEnvironment()
})
if initErr != nil {
return nil, fmt.Errorf("onnxruntime initialize error: %w", initErr)
}
// 预先创建输入输出 Tensor
inputShape := ort.NewShape(1, 3, int64(inputH), int64(inputW))
inData := make([]float32, inputShape.FlattenedSize())
inTensor, err := ort.NewTensor(inputShape, inData)
if err != nil {
return nil, fmt.Errorf("create input tensor failed: %w", err)
}
outShape := ort.NewShape(1, 1000)
outTensor, err := ort.NewEmptyTensor[float32](outShape)
if err != nil {
inTensor.Destroy()
return nil, fmt.Errorf("create output tensor failed: %w", err)
}
// 创建 Session
session, err := ort.NewSession[float32](
modelPath,
[]string{defaultInputName},
[]string{defaultOutputName},
[]*ort.Tensor[float32]{inTensor},
[]*ort.Tensor[float32]{outTensor},
)
if err != nil {
inTensor.Destroy()
outTensor.Destroy()
return nil, fmt.Errorf("create onnx session failed: %w", err)
}
// 读取 label 文件
labels, err := loadLabels(labelPath)
if err != nil {
session.Destroy()
inTensor.Destroy()
outTensor.Destroy()
return nil, err
}
return &ImageRecognizer{
session: session,
inputName: defaultInputName,
outputName: defaultOutputName,
inputH: inputH,
inputW: inputW,
labels: labels,
inputTensor: inTensor,
outputTensor: outTensor,
}, nil
}
核心图像识别逻辑,经历格式转换,调用推理(底层使用cgo),最后取结果最大值映射到label
// 核心推理函数
func (r *ImageRecognizer) PredictFromImage(img image.Image) (string, error) {
resizedImg := image.NewRGBA(image.Rect(0, 0, r.inputW, r.inputH))
// 缩放图片
draw.CatmullRom.Scale(resizedImg, resizedImg.Bounds(), img, img.Bounds(), draw.Over, nil)
h, w := r.inputH, r.inputW
ch := 3 // R, G, B
data := make([]float32, h*w*ch)
for y := 0; y < h; y++ {
for x := 0; x < w; x++ {
c := resizedImg.At(x, y)
r, g, b, _ := c.RGBA()
rf := float32(r>>8) / 255.0
gf := float32(g>>8) / 255.0
bf := float32(b>>8) / 255.0
// NCHW format
data[y*w+x] = rf
data[h*w+y*w+x] = gf
data[2*h*w+y*w+x] = bf
}
}
// 吸入tensor
inData := r.inputTensor.GetData()
copy(inData, data)
// 执行模型
if err := r.session.Run(); err != nil {
return "", fmt.Errorf("onnx run error: %w", err)
}
// 获取结果
outData := r.outputTensor.GetData()
if len(outData) == 0 {
return "", errors.New("empty output from model")
}
// 取最大值
maxIdx := 0
maxVal := outData[0]
for i := 1; i < len(outData); i++ {
if outData[i] > maxVal {
maxVal = outData[i]
maxIdx = i
}
}
// 映射label
if maxIdx >= 0 && maxIdx < len(r.labels) {
return r.labels[maxIdx], nil
}
return "Unknown", nil
}
项目启动流程
项目从main.go启动,会首先从配置文件加载配置项,随后初始化MySQL,调用函数从数据库加载对话历史到各个AIHelper,随后初始化Redis和RabbitMQ,最后启动Gin服务器
// 启动服务器
func StartServer(addr string, port int) error {
// 初始化URL路由规则并返回*gin.Engine
r := router.InitRouter()
//服务器静态资源路径映射关系,这里目前不需要
// r.Static(config.GetConfig().HttpFilePath, config.GetConfig().MusicFilePath)
// 在指定IP:port启动服务器
return r.Run(fmt.Sprintf("%s:%d", addr, port))
}
// 从数据库加载消息并初始化 AIHelperManager
func readDataFromDB() error {
// 获取全局AIHelperManager实例
manager := aihelper.GetGlobalManager()
// 从数据库读取所有消息
msgs, err := message.GetAllMessages()
if err != nil {
return err
}
// 遍历数据库消息
for i := range msgs {
m := &msgs[i]
// 默认openai模型
modelType := "1"
config := make(map[string]any)
// 创建对应的 AIHelper
helper, err := manager.GetOrCreateAIHelper(m.UserName, m.SessionID, modelType, config)
if err != nil {
log.Printf("[readDataFromDB] failed to create helper for user=%s session=%s: %v", m.UserName, m.SessionID, err)
continue
}
// log.Println("readDataFromDB init: ", helper.SessionID)
// 把历史对话记录加载到helper
helper.AddMessage(m.Content, m.UserName, m.IsUser, false)
}
log.Println("==============AIHelperManager init success==============")
return nil
}
func main() {
// 加载配置项(从config.toml)
conf := config.GetConfig()
host := conf.MainConfig.Host
port := conf.MainConfig.Port
//初始化mysql,建立MsSQL连接
if err := mysql.InitMysql(); err != nil {
log.Println("InitMysql error , " + err.Error())
return
}
//从数据库加载消息并初始化AIHelperManager
readDataFromDB()
//初始化redis(加载缓存)
redis.Init()
log.Println("==============redis init success==============")
// 初始化rabbitmq(初始化消息队列,启动消费者)
rabbitmq.InitRabbitMQ()
log.Println("==============rabbitmq init success==============")
// 启动服务器
err := StartServer(host, port)
if err != nil {
panic(err)
}
}
项目采用前后端分离模式,启动后端server后还需要启动前端Vue服务器




( ๑´•ω•) “(ㆆᴗㆆ)