@@ -77,7 +77,7 @@ def __init__(
77
77
self .custom_params = custom_params
78
78
79
79
# Process some special cases
80
- if self .temperature < _SAMPLING_EPS :
80
+ if 0 <= self .temperature < _SAMPLING_EPS :
81
81
# top_k = 1 means greedy sampling
82
82
self .temperature = 1.0
83
83
self .top_k = 1
@@ -93,9 +93,9 @@ def verify(self):
93
93
raise ValueError (f"top_p must be in (0, 1], got { self .top_p } ." )
94
94
if not 0.0 <= self .min_p <= 1.0 :
95
95
raise ValueError (f"min_p must be in [0, 1], got { self .min_p } ." )
96
- if self .top_k < - 1 or self .top_k == 0 :
96
+ if self .top_k < 1 or self .top_k == - 1 :
97
97
raise ValueError (
98
- f"top_k must be -1 (disable), or at least 1, " f" got { self .top_k } ."
98
+ f"top_k must be -1 (disable) or at least 1, got { self .top_k } ."
99
99
)
100
100
if not - 2.0 <= self .frequency_penalty <= 2.0 :
101
101
raise ValueError (
@@ -108,12 +108,12 @@ def verify(self):
108
108
)
109
109
if not 0.0 <= self .repetition_penalty <= 2.0 :
110
110
raise ValueError (
111
- "repetition_penalty must be in ( 0, 2], got "
111
+ "repetition_penalty must be in [ 0, 2], got "
112
112
f"{ self .repetition_penalty } ."
113
113
)
114
114
if not 0 <= self .min_new_tokens :
115
115
raise ValueError (
116
- f"min_new_tokens must be in ( 0, max_new_tokens], got "
116
+ f"min_new_tokens must be in [ 0, max_new_tokens], got "
117
117
f"{ self .min_new_tokens } ."
118
118
)
119
119
if self .max_new_tokens is not None :
@@ -123,7 +123,7 @@ def verify(self):
123
123
)
124
124
if not self .min_new_tokens <= self .max_new_tokens :
125
125
raise ValueError (
126
- f"min_new_tokens must be in ( 0, max_new_tokens({ self .max_new_tokens } )], got "
126
+ f"min_new_tokens must be in [ 0, max_new_tokens({ self .max_new_tokens } )], got "
127
127
f"{ self .min_new_tokens } ."
128
128
)
129
129
grammars = [
0 commit comments