signsur4739379373 committed on
Commit
2cd734b
·
verified ·
1 Parent(s): 19cfc51

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -12
app.py CHANGED
@@ -253,6 +253,7 @@ pipe.load_lora_weights(
253
  pipe.fuse_lora()
254
  print("lightning lora fused.")
255
 
 
256
  # --- 2. manual surgery for lokr (snofs) ---
257
  print("attempting manual lokr injection for snofs...")
258
 
@@ -274,32 +275,38 @@ try:
274
  prefixes.add(key.replace(".lokr_w1", ""))
275
 
276
  for prefix in prefixes:
277
- # extract weights
278
  w1 = state_dict[f"{prefix}.lokr_w1"].to(device, dtype=dtype)
279
  w2 = state_dict[f"{prefix}.lokr_w2"].to(device, dtype=dtype)
280
  alpha = state_dict.get(f"{prefix}.alpha", None)
281
 
282
- # calculate scaling
283
- # lokr usually uses alpha / sqrt(rank) or similar, but often just alpha is enough
284
- # if alpha is present, scale = alpha / w1.shape[0] (or similar convention)
285
- # here we will assume simple multiplication or alpha scaling if provided
286
- scale = lokr_scale
287
  if alpha is not None:
288
- scale *= (alpha / w1.shape[0]) # standard lora scaling convention, might vary for lokr
 
 
 
289
 
290
  # compute delta: kronecker product
291
  # w1: (a, b), w2: (c, d) -> result: (a*c, b*d)
292
- # torch.kron is (a*c, b*d)
293
- delta = torch.kron(w1, w2) * scale
294
 
295
  # find target layer in model
296
- # prefix example: "transformer_blocks.0.attn.add_k_proj"
297
- # pipe.transformer matches this structure directly
298
  path_parts = prefix.split('.')
299
  target = pipe.transformer
 
 
300
  try:
301
  for part in path_parts:
302
  target = getattr(target, part)
 
 
 
 
 
 
 
303
 
304
  # check shapes
305
  if target.weight.shape == delta.shape:
@@ -307,12 +314,14 @@ try:
307
  updates += 1
308
  else:
309
  print(f"shape mismatch for {prefix}: model {target.weight.shape} vs lora {delta.shape}")
310
- except AttributeError:
311
  print(f"layer not found: {prefix}")
312
 
313
  print(f"successfully injected {updates} lokr layers manually.")
314
 
315
  except Exception as e:
 
 
316
  print(f"lokr injection failed: {e}")
317
  print("running with lightning lora only.")
318
 
 
253
  pipe.fuse_lora()
254
  print("lightning lora fused.")
255
 
256
+
257
  # --- 2. manual surgery for lokr (snofs) ---
258
  print("attempting manual lokr injection for snofs...")
259
 
 
275
  prefixes.add(key.replace(".lokr_w1", ""))
276
 
277
  for prefix in prefixes:
278
+ # extract weights and FORCE TO DEVICE
279
  w1 = state_dict[f"{prefix}.lokr_w1"].to(device, dtype=dtype)
280
  w2 = state_dict[f"{prefix}.lokr_w2"].to(device, dtype=dtype)
281
  alpha = state_dict.get(f"{prefix}.alpha", None)
282
 
283
+ # handle scale/alpha math carefully
284
+ current_scale = lokr_scale
 
 
 
285
  if alpha is not None:
286
+ # alpha may be a tensor; if so, move it to the same device and dtype as the weights
287
+ if isinstance(alpha, torch.Tensor):
288
+ alpha = alpha.to(device, dtype=dtype)
289
+ current_scale *= (alpha / w1.shape[0])
290
 
291
  # compute delta: kronecker product
292
  # w1: (a, b), w2: (c, d) -> result: (a*c, b*d)
293
+ delta = torch.kron(w1, w2) * current_scale
 
294
 
295
  # find target layer in model
 
 
296
  path_parts = prefix.split('.')
297
  target = pipe.transformer
298
+
299
+ layer_found = True
300
  try:
301
  for part in path_parts:
302
  target = getattr(target, part)
303
+ except AttributeError:
304
+ layer_found = False
305
+
306
+ if layer_found:
307
+ # double check devices before adding
308
+ if target.weight.device != delta.device:
309
+ delta = delta.to(target.weight.device)
310
 
311
  # check shapes
312
  if target.weight.shape == delta.shape:
 
314
  updates += 1
315
  else:
316
  print(f"shape mismatch for {prefix}: model {target.weight.shape} vs lora {delta.shape}")
317
+ else:
318
  print(f"layer not found: {prefix}")
319
 
320
  print(f"successfully injected {updates} lokr layers manually.")
321
 
322
  except Exception as e:
323
+ import traceback
324
+ traceback.print_exc()
325
  print(f"lokr injection failed: {e}")
326
  print("running with lightning lora only.")
327