diff --git a/server/internal/machine/api/machine.go b/server/internal/machine/api/machine.go index 67860cb5..684bd2ed 100644 --- a/server/internal/machine/api/machine.go +++ b/server/internal/machine/api/machine.go @@ -261,7 +261,7 @@ func (m *Machine) WsGuacamole(g *gin.Context) { return } - err = mi.IfUseSshTunnelChangeIpPort() + err = mi.IfUseSshTunnelChangeIpPort(true) if err != nil { return } diff --git a/server/internal/machine/mcm/machine.go b/server/internal/machine/mcm/machine.go index d8ae3482..61fe834d 100644 --- a/server/internal/machine/mcm/machine.go +++ b/server/internal/machine/mcm/machine.go @@ -5,6 +5,7 @@ import ( tagentity "mayfly-go/internal/tag/domain/entity" "mayfly-go/pkg/errorx" "mayfly-go/pkg/logx" + "mayfly-go/pkg/utils/netx" "net" "time" @@ -15,8 +16,8 @@ import ( type MachineInfo struct { Key string `json:"key"` // 缓存key Id uint64 `json:"id"` - Code string `json:"code"` Name string `json:"name"` + Code string `json:"code"` Protocol int `json:"protocol"` Ip string `json:"ip"` // IP地址 @@ -35,12 +36,12 @@ type MachineInfo struct { CodePath []string `json:"codePath"` } -func (m *MachineInfo) UseSshTunnel() bool { - return m.SshTunnelMachine != nil +func (mi *MachineInfo) UseSshTunnel() bool { + return mi.SshTunnelMachine != nil } -func (m *MachineInfo) GetTunnelId() string { - return fmt.Sprintf("machine:%d", m.Id) +func (mi *MachineInfo) GetTunnelId() string { + return fmt.Sprintf("machine:%d", mi.Id) } // 连接 @@ -48,7 +49,7 @@ func (mi *MachineInfo) Conn() (*Cli, error) { logx.Infof("[%s]机器连接:%s:%d", mi.Name, mi.Ip, mi.Port) // 如果使用了ssh隧道,则修改机器ip port为暴露的ip port - err := mi.IfUseSshTunnelChangeIpPort() + err := mi.IfUseSshTunnelChangeIpPort(false) if err != nil { return nil, errorx.NewBiz("ssh隧道连接失败: %s", err.Error()) } @@ -66,33 +67,39 @@ func (mi *MachineInfo) Conn() (*Cli, error) { } // 如果使用了ssh隧道,则修改机器ip port为暴露的ip port -func (me *MachineInfo) IfUseSshTunnelChangeIpPort() error { - if !me.UseSshTunnel() { +func (mi *MachineInfo) IfUseSshTunnelChangeIpPort(out bool) error { + if !mi.UseSshTunnel() { return nil } - originId := me.Id + originId := mi.Id if originId == 0 { // 随机设置一个id,如果使用了隧道则用于临时保存隧道 - me.Id = uint64(time.Now().Nanosecond()) + mi.Id = uint64(time.Now().Nanosecond()) } - sshTunnelMachine, err := GetSshTunnelMachine(int(me.SshTunnelMachine.Id), func(u uint64) (*MachineInfo, error) { - return me.SshTunnelMachine, nil + sshTunnelMachine, err := GetSshTunnelMachine(int(mi.SshTunnelMachine.Id), func(u uint64) (*MachineInfo, error) { + return mi.SshTunnelMachine, nil }) if err != nil { return err } - exposeIp, exposePort, err := sshTunnelMachine.OpenSshTunnel(me.GetTunnelId(), me.Ip, me.Port) + exposeIp, exposePort, err := sshTunnelMachine.OpenSshTunnel(mi.GetTunnelId(), mi.Ip, mi.Port) if err != nil { return err } + + // 是否获取局域网的本地IP + if out { + exposeIp = netx.GetOutBoundIP() + } + // 修改机器ip地址 - me.Ip = exposeIp - me.Port = exposePort + mi.Ip = exposeIp + mi.Port = exposePort // 代理之后置空跳板机信息,防止重复跳 - me.TempSshMachineId = me.SshTunnelMachine.Id - me.SshTunnelMachine = nil + mi.TempSshMachineId = mi.SshTunnelMachine.Id + mi.SshTunnelMachine = nil return nil } diff --git a/server/internal/machine/mcm/sshtunnel.go b/server/internal/machine/mcm/sshtunnel.go index 70b758c3..a6edcebe 100644 --- a/server/internal/machine/mcm/sshtunnel.go +++ b/server/internal/machine/mcm/sshtunnel.go @@ -89,7 +89,7 @@ func (stm *SshTunnelMachine) OpenSshTunnel(id string, ip string, port int) (expo return "", 0, err } - localHost := "127.0.0.1" + localHost := "0.0.0.0" localAddr := fmt.Sprintf("%s:%d", localHost, localPort) listener, err := net.Listen("tcp", localAddr) if err != nil { diff --git a/server/pkg/utils/netx/netx.go b/server/pkg/utils/netx/netx.go index 9da2aa91..aab24b0f 100644 --- a/server/pkg/utils/netx/netx.go +++ b/server/pkg/utils/netx/netx.go @@ -3,6 +3,7 @@ package netx import ( "mayfly-go/pkg/logx" "net" + "strings" "github.com/lionsoul2014/ip2region/binding/golang/xdb" ) @@ -68,3 +69,13 @@ func Ip2Region(ip string) string { } return region } + +func GetOutBoundIP() string { + conn, err := net.Dial("udp", "8.8.8.8:53") + if err != nil { + return "0.0.0.0" + } + localAddr := conn.LocalAddr().(*net.UDPAddr) + ip := strings.Split(localAddr.String(), ":")[0] + return ip +} diff --git a/server/pkg/utils/netx/netx_test.go b/server/pkg/utils/netx/netx_test.go new file mode 100644 index 00000000..7a592303 --- /dev/null +++ b/server/pkg/utils/netx/netx_test.go @@ -0,0 +1,10 @@ +package netx + +import ( + "fmt" + "testing" +) + +func TestIp(t *testing.T) { + fmt.Println(GetOutBoundIP()) +}