Files
mayfly-go/server/internal/machine/mcm/terminal_session.go

273 lines
6.1 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package mcm
import (
"context"
"fmt"
"io"
"mayfly-go/pkg/errorx"
"mayfly-go/pkg/logx"
"github.com/may-fly/cast"
"strings"
"time"
"unicode/utf8"
"github.com/gorilla/websocket"
)
// Message type codes sent by the websocket client as the first field of each frame.
const (
	Resize = 1 // pty window size change; frame is "1|rows|cols"
	Data   = 2 // terminal input; frame is "2|<data>"
	Ping   = 3 // ssh session liveness check
)

// MsgSplit separates the fields of a client websocket frame.
const MsgSplit = "|"
// TerminalSession bridges one ssh terminal and one websocket connection.
// It is created by NewTerminalSession, run by Start and torn down by Stop.
type TerminalSession struct {
ID string
wsConn *websocket.Conn
terminal *Terminal
handler *TerminalHandler // nil unless commands are logged or filtered (see NewTerminalSession)
recorder *Recorder // optional; when non-nil, terminal output is recorded for playback
ctx context.Context // session lifetime; the pump loops return once it is done
cancel context.CancelFunc // called by Stop
dataChan chan rune // unbuffered; carries terminal output runes from readFromTerminal to writeToWebsocket
tick *time.Ticker // drives periodic flushes of buffered output to the websocket
}
// CreateTerminalSessionParam holds the inputs for NewTerminalSession.
type CreateTerminalSessionParam struct {
SessionId string
Cli *Cli // ssh client the terminal is opened on
WsConn *websocket.Conn // websocket the terminal is bridged to
Rows int // initial pty height
Cols int // initial pty width
Recorder *Recorder // optional; when non-nil, output is recorded for playback
LogCmd bool // whether to record executed commands
CmdFilterFuncs []CmdFilterFunc // command filters
}
// NewTerminalSession opens an ssh terminal on param.Cli and wraps it in a
// TerminalSession bound to param.WsConn.
//
// It requests an "xterm-256color" pty of param.Rows x param.Cols and starts a
// shell. If either step fails, the freshly opened terminal is closed before the
// error is returned, so a failed setup does not leak it.
//
// A TerminalHandler is created only when commands are logged (param.LogCmd) or
// at least one command filter is supplied. On success, clear-screen escape
// sequences are sent to the client (best effort).
func NewTerminalSession(param *CreateTerminalSessionParam) (*TerminalSession, error) {
	sessionId, rows, cols := param.SessionId, param.Rows, param.Cols
	cli, ws, recorder := param.Cli, param.WsConn, param.Recorder

	terminal, err := NewTerminal(cli)
	if err != nil {
		return nil, err
	}
	if err := terminal.RequestPty("xterm-256color", rows, cols); err != nil {
		// Best effort: release the terminal just opened; the setup error is what the caller needs.
		_ = terminal.Close()
		return nil, err
	}
	if err := terminal.Shell(); err != nil {
		_ = terminal.Close()
		return nil, err
	}

	if recorder != nil {
		// NOTE(review): the header height is rows-3 rather than rows; the reason is not visible here — confirm.
		recorder.WriteHeader(rows-3, cols)
	}

	var handler *TerminalHandler
	// A handler is only needed when commands are logged or filtered; an empty
	// filter list does not require one.
	if param.LogCmd || len(param.CmdFilterFuncs) > 0 {
		handler = &TerminalHandler{Parser: NewParser(120, 40), Filters: param.CmdFilterFuncs}
	}

	ctx, cancel := context.WithCancel(context.Background())
	ts := &TerminalSession{
		ID:       sessionId,
		wsConn:   ws,
		terminal: terminal,
		handler:  handler,
		recorder: recorder,
		ctx:      ctx,
		cancel:   cancel,
		dataChan: make(chan rune),
		// Interval at which writeToWebsocket flushes buffered terminal output.
		tick: time.NewTicker(60 * time.Millisecond),
	}

	// Clear screen, clear scrollback, move cursor home. Best effort: a failure
	// here surfaces on the next websocket write anyway.
	_ = ts.WriteToWs("\033[2J\033[3J\033[1;1H")
	return ts, nil
}
// Start runs the session. Two background goroutines pump terminal output to the
// websocket. The calling goroutine blocks in receiveWsMsg, forwarding client
// messages to the terminal until that loop returns.
//
// Pointer receiver, named ts, for consistency with the type's other methods.
func (ts *TerminalSession) Start() {
	go ts.readFromTerminal()
	go ts.writeToWebsocket()
	ts.receiveWsMsg()
}
// Stop tears the session down. It stops the output flush ticker, cancels the
// session context so the pump loops can return, and closes the ssh terminal.
// An io.EOF from closing the terminal is not logged; any other close error is.
func (ts *TerminalSession) Stop() {
	logx.Debug("close machine ssh terminal session")
	ts.tick.Stop()
	ts.cancel()
	if ts.terminal == nil {
		return
	}
	if err := ts.terminal.Close(); err != nil && err != io.EOF {
		logx.Errorf("关闭机器ssh终端失败: %s", err.Error())
	}
}
// GetExecCmds returns all commands executed in the terminal session, as
// collected by the session's TerminalHandler. When no handler was created
// (command logging disabled and no filters), it returns an empty, non-nil slice.
func (ts *TerminalSession) GetExecCmds() []*ExecutedCmd {
	if ts.handler == nil {
		return []*ExecutedCmd{}
	}
	return ts.handler.ExecutedCmds
}
// readFromTerminal reads the ssh terminal output rune by rune and hands each
// rune to writeToWebsocket over dataChan. It returns when the session context
// is done or when reading the terminal fails (io.EOF is not logged).
func (ts *TerminalSession) readFromTerminal() {
	for {
		select {
		case <-ts.ctx.Done():
			return
		default:
		}

		rn, size, err := ts.terminal.ReadRune()
		if err != nil {
			if err != io.EOF {
				logx.Error("机器ssh终端读取消息失败: ", err)
			}
			return
		}
		if size == 0 {
			continue
		}

		// dataChan is unbuffered. Without the ctx case, this send would block
		// forever once writeToWebsocket has returned, leaking this goroutine
		// after the session is stopped.
		select {
		case ts.dataChan <- rn:
		case <-ts.ctx.Done():
			return
		}
	}
}
// writeToWebsocket batches terminal output runes arriving on dataChan. On every
// tick it flushes the non-empty batch: first to the optional TerminalHandler,
// then to the websocket, then to the optional recorder. Runes equal to
// utf8.RuneError are replaced with "@".
//
// It returns when the session context is done or when a websocket write fails.
// In the latter case it also cancels the session context, since the connection
// is unusable and the other session loops watch that context.
func (ts *TerminalSession) writeToWebsocket() {
	var buf []byte
	for {
		select {
		case <-ts.ctx.Done():
			return
		case <-ts.tick.C:
			if len(buf) == 0 {
				continue
			}
			if ts.handler != nil {
				ts.handler.HandleRead(buf)
			}
			s := string(buf)
			if err := ts.WriteToWs(s); err != nil {
				logx.Error("机器ssh终端发送消息至websocket失败: ", err)
				ts.cancel()
				return
			}
			// Record output for playback when a recorder is attached.
			if ts.recorder != nil {
				ts.recorder.Lock()
				ts.recorder.WriteData(OutPutType, s)
				ts.recorder.Unlock()
			}
			// Start a fresh slice rather than reusing buf[:0]: HandleRead may
			// keep a reference to the bytes it was given.
			buf = nil
		case data := <-ts.dataChan:
			if data == utf8.RuneError {
				buf = append(buf, '@')
				continue
			}
			// Encode into a stack array instead of allocating a slice per rune.
			var enc [utf8.UTFMax]byte
			n := utf8.EncodeRune(enc[:], data)
			buf = append(buf, enc[:n]...)
		}
	}
}
// WsMsg is a decoded client websocket frame, as produced by parseMsg.
type WsMsg struct {
Type int `json:"type"` // Resize, Data or Ping; receiveWsMsg ignores other values
Msg string `json:"msg"` // everything after the first MsgSplit
Cols int `json:"cols"` // filled in only for Resize frames
Rows int `json:"rows"` // filled in only for Resize frames
}
// receiveWsMsg reads messages sent by the websocket client and applies them to
// the ssh terminal session:
//   - Resize: changes the pty window size when both rows and cols are positive.
//   - Data: writes the payload to the terminal. If a TerminalHandler is present,
//     its PreWriteHandle runs first; when it returns an error, that error text is
//     sent to the client and EOT is written to the terminal instead of the payload.
//   - Ping: sends an ssh "ping" request and ends the loop if it fails.
//
// It returns when reading the websocket fails, a message cannot be parsed, or a
// ping fails. The session context is checked only between messages, because
// ReadMessage blocks without observing it.
func (ts *TerminalSession) receiveWsMsg() {
	for {
		select {
		case <-ts.ctx.Done():
			return
		default:
		}

		_, wsData, err := ts.wsConn.ReadMessage()
		if err != nil {
			logx.Debugf("机器ssh终端读取websocket消息失败: %s", err.Error())
			return
		}

		msgObj, err := parseMsg(wsData)
		if err != nil {
			ts.WriteToWs(GetErrorContentRn("消息内容解析失败..."))
			logx.Error("机器ssh终端消息解析失败: ", err)
			return
		}

		switch msgObj.Type {
		case Resize:
			if msgObj.Cols <= 0 || msgObj.Rows <= 0 {
				continue
			}
			if err := ts.terminal.WindowChange(msgObj.Rows, msgObj.Cols); err != nil {
				logx.Errorf("ssh pty change windows size failed: %s", err.Error())
			}
		case Data:
			data := []byte(msgObj.Msg)
			if ts.handler != nil {
				if err := ts.handler.PreWriteHandle(data); err != nil {
					ts.WriteToWs(err.Error())
					// Send the command-termination character to the terminal.
					ts.terminal.Write([]byte{EOT})
					continue
				}
			}
			// A fresh conversion is written, so the terminal never sees any change
			// PreWriteHandle might make to data.
			// NOTE(review): reuse data if PreWriteHandle is known not to mutate it.
			if _, err := ts.terminal.Write([]byte(msgObj.Msg)); err != nil {
				logx.Errorf("写入数据至ssh终端失败: %s", err)
				ts.WriteToWs(GetErrorContentRn(fmt.Sprintf("写入数据至ssh终端失败: %s", err.Error())))
			}
		case Ping:
			if _, err := ts.terminal.SshSession.SendRequest("ping", true, nil); err != nil {
				ts.WriteToWs(GetErrorContentRn("终端连接已断开..."))
				return
			}
		}
	}
}
// WriteToWs sends msg to the client as a single websocket text frame and
// returns the write error, if any.
//
// NOTE(review): this is called both from the writeToWebsocket goroutine and
// from the receiveWsMsg loop. gorilla/websocket allows only one concurrent
// writer per connection, so these writes likely need to be serialized
// (e.g. a mutex on TerminalSession) — confirm.
func (ts *TerminalSession) WriteToWs(msg string) error {
	frame := []byte(msg)
	return ts.wsConn.WriteMessage(websocket.TextMessage, frame)
}
// parseMsg decodes a raw client websocket frame of the form "msgType|msgContent".
// For Resize frames the content is "rows|cols".
//
// msgType is converted with cast.ToIntD using Ping as the fallback. rows and
// cols are converted with 80 as the fallback. A Resize frame with no cols field
// (e.g. "1|24") leaves Cols at 0 instead of panicking; receiveWsMsg skips
// resizes whose rows or cols are not positive.
//
// It returns an error when the frame contains no MsgSplit at all.
func parseMsg(msg []byte) (*WsMsg, error) {
	msgStr := string(msg)

	// Everything before the first separator is the type; the rest is content.
	index := strings.Index(msgStr, MsgSplit)
	if index == -1 {
		return nil, errorx.NewBiz("消息内容不符合指定规则")
	}
	msgType := cast.ToIntD(msgStr[:index], Ping)
	msgContent := msgStr[index+1:]

	wsMsg := &WsMsg{Type: msgType, Msg: msgContent}
	if msgType == Resize {
		rowsAndCols := strings.Split(msgContent, MsgSplit)
		wsMsg.Rows = cast.ToIntD(rowsAndCols[0], 80)
		// Split always yields at least one element, but the cols field comes
		// from the client and may be missing; guard the index.
		if len(rowsAndCols) > 1 {
			wsMsg.Cols = cast.ToIntD(rowsAndCols[1], 80)
		}
	}
	return wsMsg, nil
}