diff --git a/common/recharge.go b/common/recharge.go index b22f866..b65dcee 100644 --- a/common/recharge.go +++ b/common/recharge.go @@ -150,13 +150,6 @@ const ( WithdrawOrderTypeAll ) -// 支付类型 -const ( - WithdrawTypeWallet = 3 - WithdrawTypeUPI = 4 - WithdrawTypeBank = 7 -) - type WithdrawOrder struct { ID uint `gorm:"primarykey"` UID int `gorm:"column:uid;not null;type:int(11)"` diff --git a/modules/pay/agropay/base.go b/modules/pay/agropay/base.go index b3873ee..c030d72 100644 --- a/modules/pay/agropay/base.go +++ b/modules/pay/agropay/base.go @@ -136,12 +136,12 @@ func (s *Sub) PackWithdrawReq() interface{} { EncryptEmail: util.CalculateMD5(r.Email), DeviceId: r.DeviceID, } - if r.PayType == common.WithdrawTypeBank { + if r.PayType == int64(common.PayTypeBank) { send.EntryType = "banks" send.AccountNo = r.CardNo send.IfscCode = r.PayCode send.BankName = r.BankName - } else if r.PayType == common.WithdrawTypeUPI { + } else if r.PayType == int64(common.PayTypeUPI) { send.EntryType = "upi" send.AccountNo = r.PayCode } else { diff --git a/modules/pay/mlpay/base.go b/modules/pay/mlpay/base.go index 4a8f9ea..60a6963 100644 --- a/modules/pay/mlpay/base.go +++ b/modules/pay/mlpay/base.go @@ -147,26 +147,31 @@ func (s *Sub) PackWithdrawReq() interface{} { func (s *Sub) CheckSign(str string) bool { log.Debug("callback:%v", s.Base.CallbackReq) - checkSign := "" - sign := str if s.Base.Opt == base.OPTPayCB { req := s.Base.CallbackReq.(*PayCallbackReq) s.Base.CallbackResp.OrderID = req.PartnerOrderNo s.Base.CallbackResp.APIOrderID = req.OrderNo s.Base.CallbackResp.Success = req.Status == 1 - checkSign = req.Sign + signStr := values.GetSignStrFormURLEncode(s.Base.C, req, "sign") + log.Debug("req:%v,pay signStr:%v", *req, signStr) + return checkSign(signStr, req.Sign, 0) } else if s.Base.Opt == base.OPTWithdrawCB { req := s.Base.CallbackReq.(*WithdrawCallbackReq) + signStr := values.GetSignStrFormURLEncode(s.Base.C, req, "sign") + log.Debug("req:%v,withdraw signStr:%v", *req, signStr) s.Base.CallbackResp.OrderID = req.PartnerWithdrawNo s.Base.CallbackResp.Success = req.Status == 1 + return checkSign(signStr, req.Sign, 1) } - if s.Base.KeyName == "" { - sign += "&key=" + s.Base.SignKey + return false +} + +func checkSign(str, sign string, t int) (pass bool) { + str = values.GetSignStrURLEncode(str, "sign") + if t == 0 { + str += "&key=" + key } else { - sign += "&" + s.Base.KeyName + "=" + s.Base.SignKey + str += "&key=" + withdrawKey } - ret := util.CalculateMD5(sign) - ret = strings.ToUpper(ret) - log.Info("SignStr:%v,SignMD5:%v", sign, ret) - return checkSign == ret + return strings.ToUpper(util.CalculateMD5(str)) == sign } diff --git a/modules/pay/values/values.go b/modules/pay/values/values.go index 9a7b9b9..ff2b190 100644 --- a/modules/pay/values/values.go +++ b/modules/pay/values/values.go @@ -6,6 +6,8 @@ import ( "io/ioutil" "math/rand" "net/http" + "net/url" + "reflect" "server/call" "server/common" "server/config" @@ -546,3 +548,87 @@ func WithdrawAmount(w PayWay, success bool, amount int64) { db.Mysql().UpdateW(&common.ConfigWithdrawChannels{}, map[string]interface{}{"amount": gorm.Expr("amount + ?", fee)}, fmt.Sprintf("channel_id = %d", int(w))) }) } + +// 根据body里的字段直接拼接出签名字符串 +func GetSignStrURLEncode(str string, pass ...string) string { + sortStrs := []string{} + all, err := url.ParseQuery(str) + if err != nil { + log.Error("err:%e", err) + return "" + } + for k := range all { + sortStrs = append(sortStrs, k) + } + + signStr := "" + sort.Strings(sortStrs) + for _, v := range sortStrs { + shouldPass := false + for _, s := range pass { + if v == s { + shouldPass = true + break + } + } + if shouldPass { + continue + } + tmp := all.Get(v) + if len(tmp) > 1 && tmp[0] == 34 { + tmp = tmp[1 : len(tmp)-1] + } + if len(tmp) == 0 { + continue + } + signStr += fmt.Sprintf("%v=%v", v, string(tmp)) + signStr += "&" + } + signStr = signStr[:len(signStr)-1] + log.Debug("signStr:%v", signStr) + return signStr +} + +func GetSignStrFormURLEncode(c *gin.Context, model interface{}, pass ...string) string { + sortStrs := []string{} + for k := range c.Request.URL.Query() { + sortStrs = append(sortStrs, k) + } + // ref := reflect.ValueOf(model) + reft := reflect.TypeOf(model) + if reft.Kind() == reflect.Ptr { + // ref = ref.Elem() + reft = reft.Elem() + } + signStr := "" + sort.Strings(sortStrs) + for _, v := range sortStrs { + shouldPass := false + for _, s := range pass { + if v == s { + shouldPass = true + break + } + } + if shouldPass { + continue + } + param := c.Query(v) + if len(param) == 0 { + continue + } + first := v[0] + tmpName := strings.ToUpper(string(first)) + v[1:] + field, ok := reft.FieldByName(tmpName) + if ok && field.Tag.Get("encode") == "1" { // 需要urlencode + param, _ = url.QueryUnescape(param) + } + + signStr += fmt.Sprintf("%v=%v", v, param) + signStr += "&" + } + fmt.Println(signStr) + signStr = signStr[:len(signStr)-1] + log.Debug("signStr:%v", signStr) + return signStr +}