Skip to content

Commit

Permalink
fix bug when bucket contains 2+ scalar tensors only
Browse files Browse the repository at this point in the history
  • Loading branch information
Dan Lykov committed Oct 4, 2024
1 parent 64abce7 commit bb6949f
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 2 deletions.
4 changes: 2 additions & 2 deletions qtensor/compression/Compressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def __init__(self, r2r_error=1e-3, r2r_threshold=1e-3):

def free_decompressed(self):
import cupy
print("Cleanup", len(self.decompressed_own))
#print("Cleanup", len(self.decompressed_own))
for x in self.decompressed_own:
del x
cupy.get_default_memory_pool().free_all_blocks()
Expand Down Expand Up @@ -192,7 +192,7 @@ def __init__(self, r2r_error=1e-3, r2r_threshold=1e-3):

def free_decompressed(self):
import cupy
print("Cleanup", len(self.decompressed_own))
#print("Cleanup", len(self.decompressed_own))
for x in self.decompressed_own:
#print(x)
#if x == None:
Expand Down
4 changes: 4 additions & 0 deletions qtensor/contraction_backends/compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ def process_bucket(self, bucket, no_sum=False):
)
if len(bucket)>1:
t = bucket[-1]
if len(t.indices) == 0:
print(f"Scalar tensor {t}, {t.data}")
accum = accum * t
return accum
total_ixs = sorted(
set().union(*[t.indices, accum.indices])
, key=int, reverse=True
Expand Down

0 comments on commit bb6949f

Please sign in to comment.