zhiyuan8 commited on
Commit
55d2dd9
·
verified ·
1 Parent(s): 1617bfc

Update wkv.py

Browse files
Files changed (1) hide show
  1. wkv.py +1 -1
wkv.py CHANGED
@@ -282,7 +282,7 @@ class Rwkv_Tmix_x070(nn.Module):
282
  self.a0 + (xa @ self.a1) @ self.a2
283
  ) # a is "in-context learning rate"
284
  if self.args.wkv_has_gate:
285
- g = torch.sigmoid(xg @ self.g1) @ self.g2 + 1.0
286
  kk = k * self.k_k
287
  kk = F.normalize(kk.view(B, T, self.n_head, -1),
288
  p=2.0, dim=-1, eps=1e-4 if kk.dtype == torch.float16 else 1e-12).view(B, T, C)
 
282
  self.a0 + (xa @ self.a1) @ self.a2
283
  ) # a is "in-context learning rate"
284
  if self.args.wkv_has_gate:
285
+ g = torch.sigmoid(xg @ self.g1) @ self.g2 #+ 1.0
286
  kk = k * self.k_k
287
  kk = F.normalize(kk.view(B, T, self.n_head, -1),
288
  p=2.0, dim=-1, eps=1e-4 if kk.dtype == torch.float16 else 1e-12).view(B, T, C)