package task import ( "bufio" "errors" "fmt" "io" "math/rand" "server/call" "server/pb" "server/util" "sync/atomic" "time" "github.com/gogo/protobuf/proto" "github.com/liangdas/mqant/log" timewheel "github.com/liangdas/mqant/module/modules/timer" "golang.org/x/net/websocket" ) const ( bufsize = 4096 ) var ( CurrentConnect *int32 StopConnect *int32 ) type WsWork struct { // lock *sync.Mutex addr string conn *websocket.Conn r *bufio.Reader register map[int]map[int]func(msg []byte) workID int tableID int seat uint32 disconnect bool *Work } func (w *WsWork) Connect() { addr := "ws://" + w.addr origin := "http://localhost/" //客户端地址 ws, err := websocket.Dial(addr, "", origin) //第二个参数是websocket子协议,可以为空 // fail := 0 // t := 10 for { if err == nil { break } // if fail >= 5 { // return // } log.Error("Connect err:%v", err) w.Wait(time.Duration(10) * time.Second) ws, err = websocket.Dial(addr, "", origin) // fail++ // t += 5 } atomic.AddInt32(CurrentConnect, 1) w.conn = ws w.r = bufio.NewReaderSize(ws, bufsize) w.StartHeartbeat() w.StartWork() w.ReadLoop() w.Close() // var buf = make([]byte, 1000) // for { // n, err := ws.Read(buf) // if err != nil { // log.Fatal(err) // } // fmt.Println("receive: ", string(buf[:n])) // } } func (w *WsWork) RegistFunc(mid, pid int, f func(msg []byte)) error { if _, ok := w.register[mid]; !ok { w.register[mid] = map[int]func(msg []byte){} } if _, ok := w.register[mid][pid]; !ok { w.register[mid][pid] = f } return nil } func (w *WsWork) StartHeartbeat() { if w.disconnect { return } timewheel.GetTimeWheel().AddTimer(10*time.Second, nil, func(arge interface{}) { w.StartHeartbeat() }) if err := w.Write(int(pb.ServerType_SERVER_TYPE_GATEWAY), int(pb.ServerGateReq_GatePingReq), nil); err != nil { log.Error("err:%v", err) } } func (w *WsWork) StartWork() { w.Login() } func (w *WsWork) Write(mID, pID int, msg proto.Message) error { // log.Debug("write to agent:%v,%v", mID, pID) if w.conn == nil { return errors.New("Client nil") } body := []byte{} var err error if msg != nil { body, err = proto.Marshal(msg) if err != nil { log.Error("err:%v", err) // fmt.Println(err) return err } } send := []byte{} length, err := util.IntToBytes(len(body)+10, 2) if err != nil { log.Error("length err:%v", err) return err } send = append(send, length...) module, err := util.IntToBytes(mID, 2) if err != nil { log.Error("module err:%v", err) return err } protocol, _ := util.IntToBytes(pID, 2) send = append(send, module...) send = append(send, protocol...) send = append(send, []byte{0, 0, 0, 0}...) send = append(send, body...) w.conn.Write(send) return nil } func (w *WsWork) readInt(n int) (ret int, err error) { // tmp := []byte{} tmp, err := w.readByte(n) // fmt.Println("readint:", tmp) if err != nil { log.Error("err:%v", err) return } ret, err = util.BytesToInt(tmp) return } func (w *WsWork) readByte(n int) (data []byte, err error) { for i := 0; i < n; i++ { var tmp byte tmp, err = w.r.ReadByte() // fmt.Println("readbyte:", tmp) if err != nil { log.Error("err:%v", err) return } data = append(data, tmp) } // fmt.Println("protobuf-bytes:", data) return } func (w *WsWork) ReadLoop() { for { // 第一步拿到数据包长度 length, err := w.readInt(2) if err != nil { log.Error("read err:%v", err) return } if length > bufsize { log.Error("max bufSize limit:%v,length:%v", bufsize, length) return } // 第二步拿到模块协议类型 moduleType, err := w.readInt(2) if err != nil { log.Error("err:%v", err) return } if moduleType == 0 { log.Error("invalid moduleType") return } if moduleType >= 3000 { moduleType = call.GetGameOriginID(moduleType) } // 第三步拿到协议类型 protocolType, err := w.readInt(2) if err != nil { log.Error("err:%v", err) return } if protocolType == 0 { log.Error("invalid protocolType") return } // 第四步拿到uid _, err = w.readInt(4) if err != nil { log.Error("err:%v", err) return } // log.Debug("uid:%v", uid) // 路由 // moduleName := pb.ModuleType_name[int32(moduleType)] // if moduleName == "" { // log.Error("unknow moduleType:%v", moduleType) // return // } // fmt.Println("==========================", moduleType, moduleName, protocolType) // var protocolName string // pr := call.GetProtocolType(moduleType) // protocolName = pr[int32(protocolType)] // if protocolName == "" { // log.Error("invalid protocolType:%v", protocolType) // return // } request := []byte{} if length > 10 { // 第五步拿到协议数据 request = make([]byte, length-10) // l, err := w.r.Read(request) l, err := io.ReadFull(w.r, request) if err != nil { log.Error("err:%v", err) return } if l != length-10 { log.Error("pack len:%v,read len:%v,module:%v,protocol:%v", length-10, l, moduleType, protocolType) return } } // 心跳包 if moduleType == 1000 && protocolType == int(pb.ServerGateResp_GatePingResp) { // fmt.Println("ping from server") continue } if r, ok1 := w.register[moduleType]; ok1 { if sr, ok2 := r[protocolType]; ok2 { sr(request) continue } } // fmt.Printf("unknow proto:%v,%v\n", moduleType, protocolType) } } func (this *WsWork) Wait(t ...time.Duration) { var t1 time.Duration if t == nil { s := rand.Intn(3000) + 1000 t1 = time.Duration(s) * time.Millisecond } else { t1 = t[0] } time.Sleep(t1) } func (this *WsWork) Close() { fmt.Printf("robot %v stop", this.Number) atomic.AddInt32(CurrentConnect, -1) atomic.AddInt32(StopConnect, 1) this.disconnect = true this.conn.Close() }