Implement compression ratio decoding (#230)#231
Conversation
|
/ok to test 044386b |
maxjeblick
left a comment
There was a problem hiding this comment.
Thanks a lot for the PR! Overall, the PR looks good, I left some comments in the code.
Please update the README with your method, as well.
| from kvpress.presses.scorer_press import ScorerPress | ||
|
|
||
|
|
||
| class CompressionRatioDecodingPress(DecodingPress): |
There was a problem hiding this comment.
CompressionRatioDecodingPress should also be defined as a dataclass, e.g. using
@dataclass
class CompressionRatioDecodingPress(DecodingPress):
target_compression_ratio: float = 0.5
target_size: int = field(default=1, init=False)
def __post_init__(self):
super().__post_init__()
assert 0 <= self.target_compression_ratio < 1, (
"target_compression_ratio must be between 0 and 1"
)
|
|
||
| def _resolve_total_tokens_seen(self, kwargs: dict) -> int: | ||
| if "position_ids" in kwargs and kwargs["position_ids"] is not None: | ||
| return int(kwargs["position_ids"].reshape(-1)[-1].item()) + 1 |
There was a problem hiding this comment.
THis should be equivalent to kwargs["position_ids"].max() + 1 ? IMO, .max() is easier to follow.
Signed-off-by: Fabio Massimo Ercoli <[email protected]>
044386b to
09513cc
Compare
|
Thank you very much, @maxjeblick, for the review! I tried to address your comments. |
|
/ok to test 044386b |
@maxjeblick, there was an error processing your request: See the following link for more information: https://docs.gha-runners.nvidia.com/cpr/e/2/ |
|
/ok to test 09513cc |
PR description
Fixes #230
Checklist
Before submitting a PR, please make sure:
Tests are working (
make test)Code is formatted correctly (
make style, on errors try fix withmake format)Copyright header is included
All commits are signed-off using
git commit -s(new press)
mypress_press.pyis in thepressesdirectory(new press)
MyPressis in__init__.py(new press)
README.mdis updated with a 1 liner about the new press in the Available presses section(new press) New press is in the
default_presseslist intests/default_presses.py(new press) A docstring is provided that follows the same structure as the existing ones