diff options
author | AUTOMATIC1111 <16777216c@gmail.com> | 2023-05-17 09:26:26 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-05-17 09:26:26 +0300 |
commit | 85232a5b26666854deae59cf950f744740dd5c37 (patch) | |
tree | 3af76d8c6ba3173ffd925336d902da058df4e02d /modules/sub_quadratic_attention.py | |
parent | 56a2672831751480f94a018f861f0143a8234ae8 (diff) | |
parent | 4b07f2f584596604c4499efb0b0295e96985080f (diff) |
Merge branch 'dev' into taesd-a
Diffstat (limited to 'modules/sub_quadratic_attention.py')
-rw-r--r-- | modules/sub_quadratic_attention.py | 17 |
1 files changed, 9 insertions, 8 deletions
diff --git a/modules/sub_quadratic_attention.py b/modules/sub_quadratic_attention.py index 05595323..497568eb 100644 --- a/modules/sub_quadratic_attention.py +++ b/modules/sub_quadratic_attention.py @@ -179,7 +179,7 @@ def efficient_dot_product_attention( chunk_idx, min(query_chunk_size, q_tokens) ) - + summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale) summarize_chunk: SummarizeChunk = partial(checkpoint, summarize_chunk) if use_checkpoint else summarize_chunk compute_query_chunk_attn: ComputeQueryChunkAttn = partial( @@ -201,14 +201,15 @@ def efficient_dot_product_attention( key=key, value=value, ) - - # TODO: maybe we should use torch.empty_like(query) to allocate storage in-advance, - # and pass slices to be mutated, instead of torch.cat()ing the returned slices - res = torch.cat([ - compute_query_chunk_attn( + + res = torch.zeros_like(query) + for i in range(math.ceil(q_tokens / query_chunk_size)): + attn_scores = compute_query_chunk_attn( query=get_query_chunk(i * query_chunk_size), key=key, value=value, - ) for i in range(math.ceil(q_tokens / query_chunk_size)) - ], dim=1) + ) + + res[:, i * query_chunk_size:i * query_chunk_size + attn_scores.shape[1], :] = attn_scores + return res |