1919import binascii
2020import gzip
2121import json
22+ import math
2223import os
2324import random
2425import tempfile
@@ -177,6 +178,9 @@ def _get_type(self):
177178 ('bitWidth' , self .bit_width )
178179 ])
179180
def _encode_values(self, values):
    """Convert generated integers into the representation stored in a column.

    Each value becomes a Python ``int`` when ``self.bit_width`` is below 64
    and a ``str`` otherwise.

    Returns a new list; ``values`` may be any iterable (list, NumPy array).
    """
    # NOTE(review): presumably 64-bit values are stringified because they
    # can exceed what JSON number parsers represent exactly -- confirm.
    convert = int if self.bit_width < 64 else str
    return [convert(value) for value in values]
183+
180184 def generate_column (self , size , name = None ):
181185 lower_bound , upper_bound = self ._get_generated_data_bounds ()
182186 return self .generate_range (size , lower_bound , upper_bound ,
@@ -187,42 +191,29 @@ def generate_range(self, size, lower, upper, name=None,
187191 values = np .random .randint (lower , upper , size = size , dtype = np .int64 )
188192 if include_extremes and size >= 2 :
189193 values [:2 ] = [lower , upper ]
190- values = list ( map ( int if self .bit_width < 64 else str , values ) )
194+ values = self ._encode_values ( values )
191195
192196 is_valid = self ._make_is_valid (size )
193197
194198 if name is None :
195199 name = self .name
196200 return PrimitiveColumn (name , size , is_valid , values )
197201
@property
def column_class(self):
    """Column type used to hold values generated for this field."""
    return PrimitiveColumn
205+
198206
class RunEndsField(IntegerField):
    """Integer field that fulfils the requirements for the run ends of REE.

    The integers are positive and in a strictly increasing sequence. The
    field is declared signed and non-nullable, and its bit width must be
    one of 16, 32 or 64.
    """

    def __init__(self, name, bit_width, *, metadata=None):
        # Explicit check instead of `assert`, which is skipped when Python
        # runs with -O.
        if bit_width not in (16, 32, 64):
            raise ValueError(
                f"run ends bit_width must be 16, 32 or 64, got {bit_width!r}")
        super().__init__(name, is_signed=True, bit_width=bit_width,
                         nullable=False, metadata=metadata)

    def generate_column(self, size, name=None):
        """Always raise: this field does not generate its own column.

        Raises:
            NotImplementedError: unconditionally.
        """
        raise NotImplementedError("cannot be generated directly")
226217
227218
228219class DateField (IntegerField ):
@@ -1159,11 +1150,32 @@ def _get_children(self):
11591150 ]
11601151
def generate_column(self, size, name=None):
    """Generate a run-end-encoded column of logical length ``size``.

    A random number of physical runs is chosen, strictly increasing run
    ends are drawn so that the last one equals ``size``, and one value per
    run is generated from ``self.values_field``.

    Args:
        size: logical length of the column; 0 produces no runs.
        name: column name; defaults to ``self.name``.

    Returns:
        A ``RunEndEncodedColumn`` built from the run-ends column and the
        values column.
    """
    # The `size` of a RunEndEncodedField is the logical length of the
    # run-end-encoded column, so we choose a number of physical runs
    # that's smaller.
    if size > 0:
        # randint excludes its upper bound and requires low < high;
        # ceil(size * 0.75) is 1 when size == 1, so clamp the bound to 2
        # to keep single-row columns generatable.
        num_runs = np.random.randint(1, max(2, math.ceil(size * 0.75)))
        if num_runs > 1:
            # Distinct interior boundaries drawn from [1, size - 1].
            interior = np.random.choice(
                size - 1, num_runs - 1, replace=False) + 1
            interior.sort()
        else:
            interior = np.empty(0, dtype=np.int64)
        # The final run always ends at the logical length. np.concatenate
        # is used because np.concat only exists in NumPy >= 2.0.
        run_ends = np.concatenate((interior, [size]))
        assert len(run_ends) == num_runs
        assert len(set(run_ends)) == num_runs
        assert (run_ends > 0).all()
        assert (run_ends <= size).all()
    else:
        num_runs = 0
        run_ends = []
    # Run ends are never null, hence a null probability of 0.
    run_ends_is_valid = self._make_is_valid(num_runs, null_probability=0)
    run_ends = self.run_ends_field._encode_values(run_ends)

    run_end_column = self.run_ends_field.column_class(
        self.run_ends_field.name, num_runs, run_ends_is_valid, run_ends)
    values = self.values_field.generate_column(num_runs)

    if name is None:
        name = self.name
    return RunEndEncodedColumn(name, size, run_end_column, values)
11671179
11681180
11691181class _BaseUnionField (Field ):
@@ -1746,11 +1758,14 @@ def generate_recursive_nested_case():
17461758
def generate_run_end_encoded_case():
    """Build the "run_end_encoded" case.

    Field names follow ``ree<run-end bit width>_<values type>``. Batches
    of 0, 7 and 20 rows are requested.
    """
    fields = [
        RunEndEncodedField('ree16_int32', 16, get_field('values', 'int32')),
        RunEndEncodedField('ree32_utf8', 32, get_field('values', 'utf8')),
        RunEndEncodedField('ree64_float32', 64,
                           get_field('values', 'float32')),
        # Name matches the 64-bit run ends actually requested.
        RunEndEncodedField('ree64_bool', 64, get_field('values', 'bool')),
        # Add a non-REE-encoded field to check column size correctness
        BooleanField('bool'),
    ]
    batch_sizes = [0, 7, 20]
    return _generate_file("run_end_encoded", fields, batch_sizes)
17551770
17561771
0 commit comments