[Go]GopherAI项目学习记录:AI聊天模块

一个网页AI聊天,图像识别项目,使用到的技术栈有,Gin框架,GORM,rabbitmq,redis,eino框架,Vue框架等

项目地址:https://github.com/youngyangyang04/GopherAI


AI聊天模块提供了调用大模型API或者本地Ollama模型的能力,登录后的用户可以创建AI会话进行聊天

下面开始自顶向下地介绍模块实现

路由层

router.go中创建了/AI路由组,该路由组需要JWT鉴权

// 初始化URL路由规则
func InitRouter() *gin.Engine {

    r := gin.Default()
    enterRouter := r.Group("/api/v1")   // 创建路由组
    {
        RegisterUserRouter(enterRouter.Group("/user"))
    }

    //后续登录的接口需要jwt鉴权
    {
        AIGroup := enterRouter.Group("/AI")
        AIGroup.Use(jwt.Auth())         // 使用JWT中间件
        AIRouter(AIGroup)
    }

    {
        ImageGroup := enterRouter.Group("/image")
        ImageGroup.Use(jwt.Auth())      // 使用JWT中间件
        ImageRouter(ImageGroup)
    }

    return r
}

/AI路由组下有AI聊天会话相关的API接口:

  • 获取会话列表
  • 创建会话并发送消息
  • 发送消息
  • 获取聊天历史
  • 创建会话并发送消息(流式传输)
  • 发送消息(流式传输)
// 为接口路径设置handler函数
func AIRouter(r *gin.RouterGroup) {

    // 聊天相关接口
    {
        r.GET("/chat/sessions", session.GetUserSessionsByUserName)                          // 获取会话列表
        r.POST("/chat/send-new-session", session.CreateSessionAndSendMessage)               // 创建会话+发消息
        r.POST("/chat/send", session.ChatSend)                                              // 发消息
        r.POST("/chat/history", session.ChatHistory)                                        // 获取聊天历史
        // r.POST("/chat/tts", AI.ChatSpeech)                  // ChatSpeechHandler
        r.POST("/chat/send-stream-new-session", session.CreateStreamSessionAndSendMessage)  // 创建会话+发消息(流式)
        r.POST("/chat/send-stream", session.ChatStreamSend)                                 // 发消息(流式)
    }
}

控制器层

实现了路由层接口的handler函数,首先依旧定义了接口的请求/响应JSON结构体,方便后面绑定JSON参数

type (
    GetUserSessionsResponse struct {
        controller.Response
        Sessions []model.SessionInfo `json:"sessions,omitempty"`
    }
    CreateSessionAndSendMessageRequest struct {
        UserQuestion string `json:"question" binding:"required"`  // 用户问题;
        ModelType    string `json:"modelType" binding:"required"` // 模型类型;
    }

    CreateSessionAndSendMessageResponse struct {
        AiInformation string `json:"Information,omitempty"` // AI回答
        SessionID     string `json:"sessionId,omitempty"`   // 当前会话ID
        controller.Response
    }

    ChatSendRequest struct {
        UserQuestion string `json:"question" binding:"required"`            // 用户问题;
        ModelType    string `json:"modelType" binding:"required"`           // 模型类型;
        SessionID    string `json:"sessionId,omitempty" binding:"required"` // 当前会话ID
    }

    ChatSendResponse struct {
        AiInformation string `json:"Information,omitempty"` // AI回答
        controller.Response
    }

    ChatHistoryRequest struct {
        SessionID string `json:"sessionId,omitempty" binding:"required"` // 当前会话ID
    }
    ChatHistoryResponse struct {
        History []model.History `json:"history"`
        controller.Response
    }
)

先来看获取会话列表的handler函数,首先会把JWT认证的时候注入到*gin.Context的username拿出来,然后调用服务层的GetUserSessionsByUserName函数获取该user的所有会话,最后构建响应并返回

// /chat/sessions接口的handler函数
func GetUserSessionsByUserName(c *gin.Context) {
    res := new(GetUserSessionsResponse)
    userName := c.GetString("userName") // From JWT middleware

    userSessions, err := session.GetUserSessionsByUserName(userName)
    if err != nil {
        c.JSON(http.StatusOK, res.CodeOf(code.CodeServerBusy))
        return
    }

    res.Success()
    res.Sessions = userSessions
    c.JSON(http.StatusOK, res)
}

再看看创建会话并发送消息的handler函数,一样的先获取了username,随后调用了服务层的CreateSessionAndSendMessage函数,在内部创建会话并发送消息,返回AI回答和当前会话ID,最后依旧构建响应并返回

// /chat/send-new-session接口的handler函数
func CreateSessionAndSendMessage(c *gin.Context) {
    req := new(CreateSessionAndSendMessageRequest)
    res := new(CreateSessionAndSendMessageResponse)
    userName := c.GetString("userName") // From JWT middleware
    if err := c.ShouldBindJSON(req); err != nil {
        c.JSON(http.StatusOK, res.CodeOf(code.CodeInvalidParams))
        return
    }
    //内部会创建会话并发送消息,并会将AI回答、当前会话返回
    session_id, aiInformation, code_ := session.CreateSessionAndSendMessage(userName, req.UserQuestion, req.ModelType)

    if code_ != code.CodeSuccess {
        c.JSON(http.StatusOK, res.CodeOf(code_))
        return
    }

    res.Success()
    res.AiInformation = aiInformation
    res.SessionID = session_id
    c.JSON(http.StatusOK, res)
}

再看看发送消息的handler函数,和上面基本一样,调用的是服务层的ChatSend函数

// /chat/send接口的handler函数
func ChatSend(c *gin.Context) {
    req := new(ChatSendRequest)
    res := new(ChatSendResponse)
    userName := c.GetString("userName") // From JWT middleware
    if err := c.ShouldBindJSON(req); err != nil {
        c.JSON(http.StatusOK, res.CodeOf(code.CodeInvalidParams))
        return
    }
    // 发送消息,并会将AI回答返回
    aiInformation, code_ := session.ChatSend(userName, req.SessionID, req.UserQuestion, req.ModelType)

    if code_ != code.CodeSuccess {
        c.JSON(http.StatusOK, res.CodeOf(code_))
        return
    }

    res.Success()
    res.AiInformation = aiInformation
    c.JSON(http.StatusOK, res)
}

再看看获取聊天历史的handler函数,一样的获取username,绑定参数流程,然后调用服务层的GetChatHistory函数,最后依旧构建响应并返回

// /chat/history接口的handler函数
func ChatHistory(c *gin.Context) {
    req := new(ChatHistoryRequest)
    res := new(ChatHistoryResponse)
    userName := c.GetString("userName") // From JWT middleware
    if err := c.ShouldBindJSON(req); err != nil {
        c.JSON(http.StatusOK, res.CodeOf(code.CodeInvalidParams))
        return
    }
    history, code_ := session.GetChatHistory(userName, req.SessionID)
    if code_ != code.CodeSuccess {
        c.JSON(http.StatusOK, res.CodeOf(code_))
        return
    }

    res.Success()
    res.History = history
    c.JSON(http.StatusOK, res)
}

再看看创建会话+发消息(流式)的handler函数,正常地获取username,绑定完JSON参数之后,开始设置SSE响应头,将HTTP连接升级为SSE长连接,随后调用服务层的CreateStreamSessionOnly函数,创建session,随后立即把刚刚创建的sessionID发给前端,这样前端可以更新侧边栏会话列表,然后调用服务层的StreamMessageToExistingSession函数进行流式传输

// /chat/send-stream-new-session接口的handler函数
func CreateStreamSessionAndSendMessage(c *gin.Context) {
    req := new(CreateSessionAndSendMessageRequest)
    userName := c.GetString("userName") // From JWT middleware
    if err := c.ShouldBindJSON(req); err != nil {
        c.JSON(http.StatusOK, gin.H{"error": "Invalid parameters"})
        return
    }

    // 设置SSE头
    c.Header("Content-Type", "text/event-stream")
    c.Header("Cache-Control", "no-cache")
    c.Header("Connection", "keep-alive")
    c.Header("Access-Control-Allow-Origin", "*")
    c.Header("X-Accel-Buffering", "no") // 禁止代理缓存

    // 先创建会话并立即把 sessionId 下发给前端,随后再开始流式输出
    sessionID, code_ := session.CreateStreamSessionOnly(userName, req.UserQuestion)
    if code_ != code.CodeSuccess {
        c.SSEvent("error", gin.H{"message": "Failed to create session"})
        return
    }

    // 先把 sessionId 通过 data 事件发送给前端,前端据此绑定当前会话,侧边栏即可出现新标签
    c.Writer.WriteString(fmt.Sprintf("data: {\"sessionId\": \"%s\"}\n\n", sessionID))
    c.Writer.Flush()

    // 然后开始把本次回答进行流式发送(包含最后的 [DONE])
    code_ = session.StreamMessageToExistingSession(userName, sessionID, req.UserQuestion, req.ModelType, http.ResponseWriter(c.Writer))
    if code_ != code.CodeSuccess {
        c.SSEvent("error", gin.H{"message": "Failed to send message"})
        return
    }
}

最后再看看发消息(流式)的handler函数,和上面的区别就是没有创建会话那一步

// /chat/send-stream接口的handler函数
func ChatStreamSend(c *gin.Context) {
    req := new(ChatSendRequest)
    userName := c.GetString("userName") // From JWT middleware
    if err := c.ShouldBindJSON(req); err != nil {
        c.JSON(http.StatusOK, gin.H{"error": "Invalid parameters"})
        return
    }

    // 设置SSE头
    c.Header("Content-Type", "text/event-stream")
    c.Header("Cache-Control", "no-cache")
    c.Header("Connection", "keep-alive")
    c.Header("Access-Control-Allow-Origin", "*")
    c.Header("X-Accel-Buffering", "no") // 禁止代理缓存

    code_ := session.ChatStreamSend(userName, req.SessionID, req.UserQuestion, req.ModelType, http.ResponseWriter(c.Writer))
    if code_ != code.CodeSuccess {
        c.SSEvent("error", gin.H{"message": "Failed to send message"})
        return
    }

}

服务层

首先来看获取用户所有会话ID的函数,使用全局aihelpermanager获取了用户的所有会话ID,随后遍历会话ID切片,构建所有会话信息结构体(ID+标题)的切片并返回

// 获取用户所有的会话
func GetUserSessionsByUserName(userName string) ([]model.SessionInfo, error) {
    manager := aihelper.GetGlobalManager()          // 获取全局aihelpermanager实例
    Sessions := manager.GetUserSessions(userName)   // 获取用户的所有会话ID

    var SessionInfos []model.SessionInfo

    // 构建所有会话信息的切片
    for _, session := range Sessions {
        SessionInfos = append(SessionInfos, model.SessionInfo{
            SessionID: session,
            Title:     session, // 暂时用sessionID作为标题,后续重构需要的时候可以更改
        })
    }

    return SessionInfos, nil
}

来看创建会话并发送消息的函数,首先会调用CreateSession函数创建一个新的session,随后调用GetOrCreateAIHelper函数获取AIHelper来管理消息,最后调用helper的GenerateResponse函数来获取获取AI回复,最后返回AI回复,会话ID等信息

func CreateSessionAndSendMessage(userName string, userQuestion string, modelType string) (string, string, code.Code) {
    // 1:创建一个新的会话
    newSession := &model.Session{
        ID:       uuid.New().String(),
        UserName: userName,
        Title:    userQuestion, // 可以根据需求设置标题,这边暂时用用户第一次的问题作为标题
    }
    createdSession, err := session.CreateSession(newSession)
    if err != nil {
        log.Println("CreateSessionAndSendMessage CreateSession error:", err)
        return "", "", code.CodeServerBusy
    }

    // 2:获取AIHelper并通过其管理消息
    manager := aihelper.GetGlobalManager()
    config := map[string]interface{}{
        "apiKey": "your-api-key", // TODO: 从配置中获取
    }
    helper, err := manager.GetOrCreateAIHelper(userName, createdSession.ID, modelType, config)
    if err != nil {
        log.Println("CreateSessionAndSendMessage GetOrCreateAIHelper error:", err)
        return "", "", code.AIModelFail
    }

    // 3:生成AI回复
    aiResponse, err_ := helper.GenerateResponse(userName, ctx, userQuestion)
    if err_ != nil {
        log.Println("CreateSessionAndSendMessage GenerateResponse error:", err_)
        return "", "", code.AIModelFail
    }

    return createdSession.ID, aiResponse.Content, code.CodeSuccess
}

来看发送消息的函数,和上面的比就少了一个创建新会话

// 发送消息 
func ChatSend(userName string, sessionID string, userQuestion string, modelType string) (string, code.Code) {
    // 1:获取AIHelper
    manager := aihelper.GetGlobalManager()
    config := map[string]interface{}{
        "apiKey": "your-api-key", // TODO: 从配置中获取
    }
    helper, err := manager.GetOrCreateAIHelper(userName, sessionID, modelType, config)
    if err != nil {
        log.Println("ChatSend GetOrCreateAIHelper error:", err)
        return "", code.AIModelFail
    }

    // 2:生成AI回复
    aiResponse, err_ := helper.GenerateResponse(userName, ctx, userQuestion)
    if err_ != nil {
        log.Println("ChatSend GenerateResponse error:", err_)
        return "", code.AIModelFail
    }

    return aiResponse.Content, code.CodeSuccess
}

来看获取会话历史消息的函数,会先获取全局aihelpermanager,然后获取该username下指定sessionID的aihelper,调用GetMassages获取所有历史消息(*model.Message类型的切片),然后根据消息顺序标记是否是用户消息,转化成历史消息格式(model.History类型的切片),最后返回

// 获取会话历史消息
func GetChatHistory(userName string, sessionID string) ([]model.History, code.Code) {
    // 获取AIHelper
    manager := aihelper.GetGlobalManager()
    helper, exists := manager.GetAIHelper(userName, sessionID)
    if !exists {
        return nil, code.CodeServerBusy
    }

    // 获取helper的所有历史消息
    messages := helper.GetMessages()

    // 转换消息为历史格式(根据消息顺序或内容判断用户/AI消息)
    history := make([]model.History, 0, len(messages))
    for i, msg := range messages {
        isUser := i%2 == 0
        history = append(history, model.History{
            IsUser:  isUser,
            Content: msg.Content,
        })
    }

    // 返回该会话的所有历史消息
    return history, code.CodeSuccess
}

来看只创建会话的函数,这里跟上面的一样

// 只创建会话
func CreateStreamSessionOnly(userName string, userQuestion string) (string, code.Code) {
    newSession := &model.Session{
        ID:       uuid.New().String(),
        UserName: userName,
        Title:    userQuestion,
    }
    createdSession, err := session.CreateSession(newSession)
    if err != nil {
        log.Println("CreateStreamSessionOnly CreateSession error:", err)
        return "", code.CodeServerBusy
    }
    return createdSession.ID, code.CodeSuccess
}

来看向已存在会话流式发送消息的函数,首先获取writer的flusher确保支持Flush,随后获取AIHelper,调用AIHelper的流式生成函数,传入回调函数,回调函数会发送这一次流式生成的消息段,最后发送结束标志

// 向已存在会话流式发送消息
func StreamMessageToExistingSession(userName string, sessionID string, userQuestion string, modelType string, writer http.ResponseWriter) code.Code {
    // 检查writer,确保 writer 支持 Flush
    flusher, ok := writer.(http.Flusher)
    if !ok {
        log.Println("StreamMessageToExistingSession: streaming unsupported")
        return code.CodeServerBusy
    }

    // 获取AIHelper
    manager := aihelper.GetGlobalManager()
    config := map[string]interface{}{
        "apiKey": "your-api-key", // TODO: 从配置中获取
    }
    helper, err := manager.GetOrCreateAIHelper(userName, sessionID, modelType, config)
    if err != nil {
        log.Println("StreamMessageToExistingSession GetOrCreateAIHelper error:", err)
        return code.AIModelFail
    }

    // 定义一个回调函数,AI生成一段执行一次
    cb := func(msg string) {
        // 直接发送数据,不转义
        // SSE 格式:data: 
<content>\n\n
        log.Printf("[SSE] Sending chunk: %s (len=%d)\n", msg, len(msg))
        _, err := writer.Write([]byte("data: " + msg + "\n\n"))
        if err != nil {
            log.Println("[SSE] Write error:", err)
            return
        }
        flusher.Flush() //  每次必须 flush 不然浏览器收不到
        log.Println("[SSE] Flushed")
    }

    // 调用AI流式生成
    log.Println("Calling StreamResponse...")
    _, err_ := helper.StreamResponse(userName, ctx, cb, userQuestion)   // 传入刚刚的回调
    if err_ != nil {
        log.Println("StreamMessageToExistingSession StreamResponse error:", err_)
        return code.AIModelFail
    }
    log.Println("StreamResponse completed successfully")

    // 发送结束标志
    _, err = writer.Write([]byte("data: [DONE]\n\n"))
    if err != nil {
        log.Println("StreamMessageToExistingSession write DONE error:", err)
        return code.AIModelFail
    }
    flusher.Flush()

    return code.CodeSuccess
}

来看创建并发送消息(流式)的函数,可以看到就是刚刚的函数封装一下

// 创建会话并发送消息(流式)
func CreateStreamSessionAndSendMessage(userName string, userQuestion string, modelType string, writer http.ResponseWriter) (string, code.Code) {

    sessionID, code_ := CreateStreamSessionOnly(userName, userQuestion)
    if code_ != code.CodeSuccess {
        return "", code_
    }

    code_ = StreamMessageToExistingSession(userName, sessionID, userQuestion, modelType, writer)
    if code_ != code.CodeSuccess {

        return sessionID, code_
    }

    return sessionID, code.CodeSuccess
}

来看流式发送消息的函数,可以看到也是封装一下刚刚的函数

// 发送消息(流式)
func ChatStreamSend(userName string, sessionID string, userQuestion string, modelType string, writer http.ResponseWriter) code.Code {

    return StreamMessageToExistingSession(userName, sessionID, userQuestion, modelType, writer)
}

AIHelper组件介绍

在介绍DAO层内容之前先介绍刚刚一直在用的AIHelper组件,AIHelper是一个结构体,存储了会话使用的AI模型,历史消息,会话ID,一个异步保存函数(用于将聊天消息异步推送到消息队列)和一个读写锁

// AIHelper AI助手结构体,包含消息历史和AI模型
type AIHelper struct {
    model    AIModel                                        // AI模型
    messages []*model.Message                               // 上下文记忆
    mu       sync.RWMutex                                   // 读写锁
    SessionID string                                        // 一个会话绑定一个AIHelper
    saveFunc  func(*model.Message) (*model.Message, error)  // 异步保存函数(发送到Rabbitmq)
}

先来看创建AIHelper的函数,初始化了AI模型,消息切片,会话ID和用于异步推送消息到消息队列的函数

// NewAIHelper 创建新的AIHelper实例
func NewAIHelper(model_ AIModel, SessionID string) *AIHelper {
    return &AIHelper{
        model:    model_,
        messages: make([]*model.Message, 0),
        // 异步推送到消息队列中的函数
        saveFunc: func(msg *model.Message) (*model.Message, error) {
            data := rabbitmq.GenerateMessageMQParam(msg.SessionID, msg.Content, msg.UserName, msg.IsUser)
            err := rabbitmq.RMQMessage.Publish(data)
            return msg, err
        },
        SessionID: SessionID,
    }
}

来看添加聊天消息的函数,先将消息封装成了Message结构体,随后将其append到helper自己内存中的消息切片中,最后还调用了helper的saveFunc将消息推送到消息队列

// addMessage 添加消息到内存中并调用自定义存储函数
func (a *AIHelper) AddMessage(Content string, UserName string, IsUser bool, Save bool) {
    // 封装为消息结构体
    userMsg := model.Message{
        SessionID: a.SessionID,
        Content:   Content,
        UserName:  UserName,
        IsUser:    IsUser,
    }
    // 添加到helper的切片成员中
    a.messages = append(a.messages, &userMsg)
    // 调用异步保存函数,推送到消息队列中
    if Save {
        a.saveFunc(&userMsg)
    }
}

此外还提供了用于修改saveFunc的函数

// SaveMessage 保存消息到数据库(通过回调函数避免循环依赖)
// 通过传入func,自己调用外部的保存函数,即可支持同步异步等多种策略
func (a *AIHelper) SetSaveFunc(saveFunc func(*model.Message) (*model.Message, error)) {
    a.saveFunc = saveFunc
}

用于获取所有对话消息历史的函数,可以看到就是返回了helper的消息切片副本

// GetMessages 获取所有消息历史
func (a *AIHelper) GetMessages() []*model.Message {
    a.mu.RLock()            // 要加锁
    defer a.mu.RUnlock()
    out := make([]*model.Message, len(a.messages))
    copy(out, a.messages)   
    return out              // 返回helper存的上下文记忆切片的副本
}

获取模型类型的函数

// GetModelType 获取模型类型
func (a *AIHelper) GetModelType() string {
    return a.model.GetModelType()
}

下面是用于同步生成AI回复的函数,在函数开始会先调用AddMessage存用户的提问到上下文,随后将用户提问Message转化为shcema,给AI框架使用,再调用AImodel的生成函数获取回复schema并转成Message,再次调用AddMessage存到上下文,最后返回

// 同步生成
func (a *AIHelper) GenerateResponse(userName string, ctx context.Context, userQuestion string) (*model.Message, error) {

    // 调用存储函数,存用户问题到上下文
    a.AddMessage(userQuestion, userName, true, true)

    a.mu.RLock()
    // 将model.Message转化成schema.Message,供AI使用
    messages := utils.ConvertToSchemaMessages(a.messages)
    a.mu.RUnlock()

    // 调用模型生成回复
    schemaMsg, err := a.model.GenerateResponse(ctx, messages)
    if err != nil {
        return nil, err
    }

    // 将schema.Message转化成model.Message
    modelMsg := utils.ConvertToModelMessage(a.SessionID, userName, schemaMsg)

    // 调用存储函数,存AI回答到上下文
    a.AddMessage(modelMsg.Content, userName, false, true)

    return modelMsg, nil
}

用于流式生成AI回复的函数流程一致,只是调用AImodel的函数变成流式的了

// 流式生成
func (a *AIHelper) StreamResponse(userName string, ctx context.Context, cb StreamCallback, userQuestion string) (*model.Message, error) {

    // 调用存储函数,存用户问题到上下文
    a.AddMessage(userQuestion, userName, true, true)

    a.mu.RLock()
    // 将model.Message转化成schema.Message,供AI使用
    messages := utils.ConvertToSchemaMessages(a.messages)
    a.mu.RUnlock()

    // 调用模型生成回复
    content, err := a.model.StreamResponse(ctx, messages, cb)
    if err != nil {
        return nil, err
    }
    // 转化成model.Message
    modelMsg := &model.Message{
        SessionID: a.SessionID,
        UserName:  userName,
        Content:   content,
        IsUser:    false,
    }

    // 调用存储函数,存AI回答到上下文
    a.AddMessage(modelMsg.Content, userName, false, true)

    return modelMsg, nil
}

AI模型接口实现

在common/aihelper/model.go中定义了AIModel接口,并编写了OpenAI和Ollama两个实现了该接口的结构体,下面是接口定义,定义了三个方法:

  • 同步生成AI回复
  • 流式生成AI回复
  • 获取AI模型类型
type StreamCallback func(msg string)
// AIModel 定义AI模型接口
type AIModel interface {
    GenerateResponse(ctx context.Context, messages []*schema.Message) (*schema.Message, error)
    StreamResponse(ctx context.Context, messages []*schema.Message, cb StreamCallback) (string, error)
    GetModelType() string
}

OpenAI实现

首先定义了OpenAIModel结构体,有一个eino框架提供的model.ToolCallingChatModel成员

type OpenAIModel struct {
    llm model.ToolCallingChatModel
}

初始化模型函数,获取API_KEY等参数后调用eino框架的函数创建LLM客户端

// 初始化模型
func NewOpenAIModel(ctx context.Context) (*OpenAIModel, error) {
    // 从系统环境变量获取API_KEY等信息
    key := os.Getenv("OPENAI_API_KEY")
    modelName := os.Getenv("OPENAI_MODEL_NAME")
    baseURL := os.Getenv("OPENAI_BASE_URL")

    // 创建底层模型客户端(使用eino)
    llm, err := openai.NewChatModel(ctx, &openai.ChatModelConfig{
        BaseURL: baseURL,
        Model:   modelName,
        APIKey:  key,
    })
    if err != nil {
        return nil, fmt.Errorf("create openai model failed: %v", err)
    }
    return &OpenAIModel{llm: llm}, nil
}

同步生成,封装了eino框架提供的函数

func (o *OpenAIModel) GenerateResponse(ctx context.Context, messages []*schema.Message) (*schema.Message, error) {
    resp, err := o.llm.Generate(ctx, messages)
    if err != nil {
        return nil, fmt.Errorf("openai generate failed: %v", err)
    }
    return resp, nil
}

流式生成,首先会调用eino框架提供的流式输出函数拿到输出流,然后就开始循环读取流,读到消息会执行两个操作,第一是将消息拼接,用于最后返回,便于聊天消息存储,第二是实时调用回调函数,每生成一个token就发给前端

// 流式生成
func (o *OpenAIModel) StreamResponse(ctx context.Context, messages []*schema.Message, cb StreamCallback) (string, error) {
    // 拿到eino提供的流
    stream, err := o.llm.Stream(ctx, messages)
    if err != nil {
        return "", fmt.Errorf("openai stream failed: %v", err)
    }
    defer stream.Close()

    var fullResp strings.Builder

    // 循环读取
    for {
        msg, err := stream.Recv()
        if err == io.EOF {
            break
        }
        if err != nil {
            return "", fmt.Errorf("openai stream recv failed: %v", err)
        }
        if len(msg.Content) > 0 {
            fullResp.WriteString(msg.Content) // 聚合,拼接完整回答

            cb(msg.Content) // 实时调用cb函数,方便主动发送给前端
        }
    }

    return fullResp.String(), nil // 返回完整内容,方便后续存储
}

Ollama实现

结构体,一样封装了eino提供的model.ToolCallingChatModel

// OllamaModel Ollama模型实现
type OllamaModel struct {
    llm model.ToolCallingChatModel
}

初始化模型函数,和OpenAI实现不同的地方是调用的是ollama.NewChatModel

func NewOllamaModel(ctx context.Context, baseURL, modelName string) (*OllamaModel, error) {
    llm, err := ollama.NewChatModel(ctx, &ollama.ChatModelConfig{
        BaseURL: baseURL,
        Model:   modelName,
    })
    if err != nil {
        return nil, fmt.Errorf("create ollama model failed: %v", err)
    }
    return &OllamaModel{llm: llm}, nil
}

同步生成函数

func (o *OllamaModel) GenerateResponse(ctx context.Context, messages []*schema.Message) (*schema.Message, error) {
    resp, err := o.llm.Generate(ctx, messages)
    if err != nil {
        return nil, fmt.Errorf("ollama generate failed: %v", err)
    }
    return resp, nil
}

流式生成函数

func (o *OllamaModel) StreamResponse(ctx context.Context, messages []*schema.Message, cb StreamCallback) (string, error) {
    stream, err := o.llm.Stream(ctx, messages)
    if err != nil {
        return "", fmt.Errorf("ollama stream failed: %v", err)
    }
    defer stream.Close()
    var fullResp strings.Builder
    for {
        msg, err := stream.Recv()
        if err == io.EOF {
            break
        }
        if err != nil {
            return "", fmt.Errorf("openai stream recv failed: %v", err)
        }
        if len(msg.Content) > 0 {
            fullResp.WriteString(msg.Content) // 聚合
            cb(msg.Content)                   // 实时调用cb函数,方便主动发送给前端
        }
    }
    return fullResp.String(), nil //返回完整内容,方便后续存储
}

AIHelper工厂模式

使用工厂模式+单例模式来管理AIModel的注册和AIHelper的创建

首先定义了一个函数类型ModelCreator,用于创建AIModel并返回,支持传入一个用于生命周期管理的context和一个键为string,值为any的参数配置map

// ModelCreator 定义模型创建函数类型(需要 context)(其实就是AIModel的构造函数)
type ModelCreator func(ctx context.Context, config map[string]interface{}) (AIModel, error)

声明了一个AIModelFactory结构体,内部存了一个值为ModelCreator类型函数的map

该类采用单例模式,使用GetGlobalFactory函数获取全局单例,首次运行该函数会初始化全局AIModelFactory单例的creators并调用registerCreators注册模型

// AIModelFactory AI模型工厂(AIModel构造函数的集合map)
type AIModelFactory struct {
    creators map[string]ModelCreator
}

var (
    globalFactory *AIModelFactory   // 全局AIModelFactory单例
    factoryOnce   sync.Once
)

// GetGlobalFactory 获取全局单例
func GetGlobalFactory() *AIModelFactory {
    factoryOnce.Do(func() {     // 全局只初始化一次
        globalFactory = &AIModelFactory{
            creators: make(map[string]ModelCreator),
        }
        globalFactory.registerCreators()    // 注册模型
    })
    return globalFactory
}

来看AIModelFactory注册模型的函数,这里已经注册了两个,一个OpenAI,一个Ollama,其中Ollama需要传入的config中带有baseURL和modelName

// 注册模型(注册AIModel的构造函数到map里)
func (f *AIModelFactory) registerCreators() {
    //OpenAI
    f.creators["1"] = func(ctx context.Context, config map[string]interface{}) (AIModel, error) {
        return NewOpenAIModel(ctx)
    }

    //Ollama
    f.creators["2"] = func(ctx context.Context, config map[string]interface{}) (AIModel, error) {
        baseURL, _ := config["baseURL"].(string)
        modelName, ok := config["modelName"].(string)
        if !ok {
            return nil, fmt.Errorf("Ollama model requires modelName")
        }
        return NewOllamaModel(ctx, baseURL, modelName)
    }
}

实际主要使用以下两个函数,一个用于创建AIModel,一个可以直接创建AIHelper

// CreateAIModel 根据类型创建 AI 模型
func (f *AIModelFactory) CreateAIModel(ctx context.Context, modelType string, config map[string]interface{}) (AIModel, error) {
    creator, ok := f.creators[modelType]
    if !ok {
        return nil, fmt.Errorf("unsupported model type: %s", modelType)
    }
    return creator(ctx, config) // 调用对应AIModel的构造方法,构造AIModel并返回
}

// CreateAIHelper 一键创建 AIHelper
func (f *AIModelFactory) CreateAIHelper(ctx context.Context, modelType string, SessionID string, config map[string]interface{}) (*AIHelper, error) {
    // 创建AIModel
    model, err := f.CreateAIModel(ctx, modelType, config)
    if err != nil {
        return nil, err
    }
    // 创建AIHelper
    return NewAIHelper(model, SessionID), nil
}

此外还提供了后续注册新的AIModel的方法

// RegisterModel 可扩展注册
func (f *AIModelFactory) RegisterModel(modelType string, creator ModelCreator) {
    f.creators[modelType] = creator
}

AIHelper管理器

AIHelperManager组件提供了用户ID+用户会话级别管理AIHelper的功能,维护一个map,key是用户ID,value是一个map,map的key是该用户的会话ID,map的value是该会话的AIHelper

// AIHelperManager AI助手管理器,管理用户-会话-AIHelper的映射关系
type AIHelperManager struct {
    helpers map[string]map[string]*AIHelper // map[用户账号(唯一)]map[会话ID]*AIHelper
    mu      sync.RWMutex
}

提供函数用于创建AIHelperManager

// NewAIHelperManager 创建新的管理器实例
func NewAIHelperManager() *AIHelperManager {
    return &AIHelperManager{
        helpers: make(map[string]map[string]*AIHelper),
    }
}

一样采用单例模式

// 全局管理器实例
var globalManager *AIHelperManager
var once sync.Once

// GetGlobalManager 获取全局管理器实例
func GetGlobalManager() *AIHelperManager {
    once.Do(func() {
        globalManager = NewAIHelperManager()
    })
    return globalManager
}

获取或创建AIHelper,根据用户ID以及会话ID,内部使用了工厂模式来创建新的helper

// 获取或创建AIHelper
func (m *AIHelperManager) GetOrCreateAIHelper(userName string, sessionID string, modelType string, config map[string]interface{}) (*AIHelper, error) {
    m.mu.Lock()
    defer m.mu.Unlock()

    // 获取用户的会话映射
    userHelpers, exists := m.helpers[userName]
    if !exists {
        userHelpers =(map[string]*AIHelper)
        m.helpers[userName] = userHelpers
    }

    // 检查会话是否已存在,存在直接返回
    helper, exists := userHelpers[sessionID]
    if exists {
        return helper, nil
    }

    // 创建新的AIHelper,使用工厂
    factory := GetGlobalFactory()
    helper, err := factory.CreateAIHelper(ctx, modelType, sessionID, config)
    if err != nil {
        return nil, err
    }

    // 记录到AIHelperManager中
    userHelpers[sessionID] = helper
    return helper, nil
}

获取或者移除指定用户指定会话的helper

// 获取指定用户的指定会话的AIHelper
func (m *AIHelperManager) GetAIHelper(userName string, sessionID string) (*AIHelper, bool) {
    m.mu.RLock()
    defer m.mu.RUnlock()

    userHelpers, exists := m.helpers[userName]
    if !exists {
        return nil, false
    }

    helper, exists := userHelpers[sessionID]
    return helper, exists
}

// 移除指定用户的指定会话的AIHelper
func (m *AIHelperManager) RemoveAIHelper(userName string, sessionID string) {
    m.mu.Lock()
    defer m.mu.Unlock()

    userHelpers, exists := m.helpers[userName]
    if !exists {
        return
    }

    delete(userHelpers, sessionID)

    // 如果用户没有会话了,清理用户映射
    if len(userHelpers) == 0 {
        delete(m.helpers, userName)
    }
}

获取指定用户所有的会话

// 获取指定用户的所有会话ID
func (m *AIHelperManager) GetUserSessions(userName string) []string {
    m.mu.RLock()
    defer m.mu.RUnlock()

    userHelpers, exists := m.helpers[userName]
    if !exists {
        return []string{}
    }

    sessionIDs := make([]string, 0, len(userHelpers))
    //取出所有的key
    for sessionID := range userHelpers {
        sessionIDs = append(sessionIDs, sessionID)
    }

    return sessionIDs
}

统一使用AIHelperManager来管理用户会话的AIHelper

数据访问层

首先是AI聊天会话和聊天消息在数据库中的数据模型

首先是会话与会话信息

// 会话模型
type Session struct {
    ID        string         `gorm:"primaryKey;type:varchar(36)" json:"id"`     // 会话ID
    UserName  string         `gorm:"index;not null" json:"username"`            // 用户名
    Title     string         `gorm:"type:varchar(100)" json:"title"`            // 会话标题
    CreatedAt time.Time      `json:"created_at"`                                // 创建于
    UpdatedAt time.Time      `json:"updated_at"`                                // 更新于
    DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`                            // 删除于
}

// 会话信息
type SessionInfo struct {
    SessionID string `json:"sessionId"`                                         // 会话ID
    Title     string `json:"name"`                                              // 会话标题
}

然后是聊天消息

// 消息模型
type Message struct {
    ID        uint      `gorm:"primaryKey;autoIncrement" json:"id"`                     // ID
    SessionID string    `gorm:"index;not null;type:varchar(36)" json:"session_id"`      // 会话ID
    UserName  string    `gorm:"type:varchar(20)" json:"username"`                       // 用户名
    Content   string    `gorm:"type:text" json:"content"`                               // 消息内容
    IsUser    bool      `gorm:"not null;" json:"is_user"`                               // 是否是用户消息
    CreatedAt time.Time `json:"created_at"`                                             // 创建于
}

// 历史消息
type History struct {
    IsUser  bool   `json:"is_user"`     // 是否是用户消息
    Content string `json:"content"`     // 消息内容
}

以下是会话相关的数据库操作封装

// 获取指定用户的所有会话
func GetSessionsByUserName(UserName int64) ([]model.Session, error) {
    var sessions []model.Session
    err := mysql.DB.Where("user_name = ?", UserName).Find(&sessions).Error
    return sessions, err
}

// 创建会话到数据库中
func CreateSession(session *model.Session) (*model.Session, error) {
    err := mysql.DB.Create(session).Error
    return session, err
}

// 通过ID获取数据库中的会话
func GetSessionByID(sessionID string) (*model.Session, error) {
    var session model.Session
    err := mysql.DB.Where("id = ?", sessionID).First(&session).Error
    return &session, err
}

以下是聊天消息相关数据库操作的封装

// 获取会话的所有聊天消息记录(按时间排序)
func GetMessagesBySessionID(sessionID string) ([]model.Message, error) {
    var msgs []model.Message
    err := mysql.DB.Where("session_id = ?", sessionID).Order("created_at asc").Find(&msgs).Error
    return msgs, err
}

// 批量获取多个会话聊天记录
func GetMessagesBySessionIDs(sessionIDs []string) ([]model.Message, error) {
    var msgs []model.Message
    if len(sessionIDs) == 0 {
        return msgs, nil
    }
    err := mysql.DB.Where("session_id IN ?", sessionIDs).Order("created_at asc").Find(&msgs).Error
    return msgs, err
}

// 在数据库中写入聊天消息记录
func CreateMessage(message *model.Message) (*model.Message, error) {
    err := mysql.DB.Create(message).Error
    return message, err
}

// 获取数据库中全部聊天记录
func GetAllMessages() ([]model.Message, error) {
    var msgs []model.Message
    err := mysql.DB.Order("created_at asc").Find(&msgs).Error
    return msgs, err
}

RabbitMQ部分

面对大量用户使用过程中产生的AI聊天记录,我们都要写入数据库储存,但是频繁的数据库操作会导致严重的系统性能下降与甚至阻塞,所以我们使用了消息队列来对数据进行异步处理,这个项目使用了RabbitMQ

首先来看消息队列里的数据格式

type MessageMQParam struct {
    SessionID string `json:"session_id"`
    Content   string `json:"content"`
    UserName  string `json:"user_name"`
    IsUser    bool   `json:"is_user"`
}

对应的序列化函数

// 序列化为消息队列数据格式
func GenerateMessageMQParam(sessionID string, content string, userName string, IsUser bool) []byte {
    param := MessageMQParam{
        SessionID: sessionID,
        Content:   content,
        UserName:  userName,
        IsUser:    IsUser,
    }
    data, _ := json.Marshal(param)
    return data
}

消费者调用的函数,用于将消息队列中取出的数据反序列化并写入数据库

// 消费者回调函数
func MQMessage(msg *amqp.Delivery) error {
    // 解析JSON
    var param MessageMQParam
    err := json.Unmarshal(msg.Body, &param)
    if err != nil {
        return err
    }
    newMsg := &model.Message{
        SessionID: param.SessionID,
        Content:   param.Content,
        UserName:  param.UserName,
        IsUser:    param.IsUser,
    }

    //消费者异步插入到数据库中
    message.CreateMessage(newMsg)
    return nil
}

采用全局使用一个connection,每个队列一个channel的方式

// 全局connection对象
// 所有RabbitMQ都会复用该对象
var conn *amqp.Connection

// 初始化connection
func initConn() {
    // 从配置文件获取连接信息
    c := config.GetConfig()
    mqUrl := fmt.Sprintf(
        "amqp://%s:%s@%s:%d/%s",
        c.RabbitmqUsername, c.RabbitmqPassword, c.RabbitmqHost, c.RabbitmqPort, c.RabbitmqVhost,
    )
    log.Println("mqUrl is  " + mqUrl)
    // 建立TCP连接
    var err error
    conn, err = amqp.Dial(mqUrl)
    if err != nil {
        log.Fatalf("RabbitMQ connection failed: %v", err) // 输出错误并退出程序
    }
}

封装了RabbitMQ结构体

// RabbitMQ RabbitMQ结构体
type RabbitMQ struct {
    conn     *amqp.Connection
    channel  *amqp.Channel
    Exchange string
    Key      string
}

// NewRabbitMQ 创建RabbitMQ对象
func NewRabbitMQ(exchange string, key string) *RabbitMQ {
    return &RabbitMQ{Exchange: exchange, Key: key}
}

// Destroy 断开 channel 和 connection
func (r *RabbitMQ) Destroy() {
    _ = r.channel.Close()
    _ = r.conn.Close()
}

创建工作队列模式的RabbitMQ实例

// NewWorkRabbitMQ 创建Work模式的RabbitMQ实例
func NewWorkRabbitMQ(queue string) *RabbitMQ {
    // 初始化rabbitmq结构体
    rabbitmq := NewRabbitMQ("", queue)

    // 初始化connection
    if conn == nil {
        initConn()
    }
    rabbitmq.conn = conn

    // 创建channel
    var err error
    rabbitmq.channel, err = rabbitmq.conn.Channel()
    if err != nil {
        panic(err.Error())
    }

    return rabbitmq
}

发送消息到队列

// Publish 发送消息
func (r *RabbitMQ) Publish(message []byte) error {
    // 创建队列(不存在时)
    // 使用默认交换机的情况下,queue即为key
    _, err := r.channel.QueueDeclare(r.Key, false, false, false, false, nil)    // 不持久化
    if err != nil {
        return err
    }

    // 调用 channel 发送消息到队列(默认交换机)
    return r.channel.Publish(r.Exchange, r.Key, false, false,
        amqp.Publishing{
            ContentType: "text/plain",
            Body:        message,
        },
    )
}

消费者函数,从队列中接收消息并消费,handle函数就是刚刚写的回调

// Consume 消费者
// handle: 消息的消费业务函数,用于消费消息
func (r *RabbitMQ) Consume(handle func(msg *amqp.Delivery) error) {
    // 创建队列
    q, err := r.channel.QueueDeclare(r.Key, false, false, false, false, nil)
    if err != nil {
        panic(err)
    }

    // 接收消息,开始消费
    msgs, err := r.channel.Consume(q.Name, "", true, false, false, false, nil)
    if err != nil {
        panic(err)
    }

    // 处理消息
    for msg := range msgs {
        if err := handle(&msg); err != nil {
            fmt.Println(err.Error())
        }
    }
}

消息队列全局唯一,并提供初始化和销毁函数

var (
    RMQMessage *RabbitMQ
)

func InitRabbitMQ() {
    // 创建MQ并启动消费者
    // 无论调用多少次 NewWorkRabbitMQ,只会创建一次连接
    // 不同队列共用一个连接,可以保持不同队列消费消息的顺序

    RMQMessage = NewWorkRabbitMQ("Message")
    go RMQMessage.Consume(MQMessage)            // 启动消费者goruntine

}

// DestroyRabbitMQ 销毁RabbitMQ
func DestroyRabbitMQ() {
    RMQMessage.Destroy()
}

Utils

生成会话ID,使用google的uuid库,有着保证全球唯一,不依赖数据库自增,分布式安全等优点

// 生成会话ID
func GenerateUUID() string {
    return uuid.New().String()
}

将Message格式与Eino框架使用的shcema格式互转

// 将 schema 消息转换为数据库可存储的格式
func ConvertToModelMessage(sessionID string, userName string, msg *schema.Message) *model.Message {
    return &model.Message{
        SessionID: sessionID,
        UserName:  userName,
        Content:   msg.Content,
    }
}

// 将数据库消息转换为 schema 消息(供 AI 使用)
func ConvertToSchemaMessages(msgs []*model.Message) []*schema.Message {
    schemaMsgs := make([]*schema.Message, 0, len(msgs))
    for _, m := range msgs {
        role := schema.Assistant
        if m.IsUser {
            role = schema.User
        }
        schemaMsgs = append(schemaMsgs, &schema.Message{
            Role:    role,
            Content: m.Content,
        })
    }
    return schemaMsgs
}

评论

  1. Sankkooos
    Android Firefox 148.0
    10 小时前
    2026-3-24 19:08:24

    哇这篇好长Σ(っ °Д °;)っ

发送评论 编辑评论


				
|´・ω・)ノ
ヾ(≧∇≦*)ゝ
(☆ω☆)
(╯‵□′)╯︵┴─┴
 ̄﹃ ̄
(/ω\)
∠( ᐛ 」∠)_
(๑•̀ㅁ•́ฅ)
→_→
୧(๑•̀⌄•́๑)૭
٩(ˊᗜˋ*)و
(ノ°ο°)ノ
(´இ皿இ`)
⌇●﹏●⌇
(ฅ´ω`ฅ)
(╯°A°)╯︵○○○
φ( ̄∇ ̄o)
ヾ(´・ ・`。)ノ"
( ง ᵒ̌皿ᵒ̌)ง⁼³₌₃
(ó﹏ò。)
Σ(っ °Д °;)っ
( ,,´・ω・)ノ"(´っω・`。)
╮(╯▽╰)╭
o(*////▽////*)q
>﹏<
( ๑´•ω•) "(ㆆᴗㆆ)
😂
😀
😅
😊
🙂
🙃
😌
😍
😘
😜
😝
😏
😒
🙄
😳
😡
😔
😫
😱
😭
💩
👻
🙌
🖕
👍
👫
👬
👭
🌚
🌝
🙈
💊
😶
🙏
🍦
🍉
😣
Source: github.com/k4yt3x/flowerhd
颜文字
Emoji
小恐龙
花!
上一篇
下一篇