diff --git a/ruby/red-arrow-format/lib/arrow-format.rb b/ruby/red-arrow-format/lib/arrow-format.rb index aea210bfb18..2c8ecbf55c7 100644 --- a/ruby/red-arrow-format/lib/arrow-format.rb +++ b/ruby/red-arrow-format/lib/arrow-format.rb @@ -16,4 +16,5 @@ # under the License. require_relative "arrow-format/file-reader" +require_relative "arrow-format/streaming-reader" require_relative "arrow-format/version" diff --git a/ruby/red-arrow-format/lib/arrow-format/error.rb b/ruby/red-arrow-format/lib/arrow-format/error.rb index 39b0b8af156..d73c4082beb 100644 --- a/ruby/red-arrow-format/lib/arrow-format/error.rb +++ b/ruby/red-arrow-format/lib/arrow-format/error.rb @@ -19,6 +19,9 @@ class Error < StandardError end class ReadError < Error + end + + class FileReadError < ReadError attr_reader :buffer def initialize(buffer, message) @buffer = buffer diff --git a/ruby/red-arrow-format/lib/arrow-format/file-reader.rb b/ruby/red-arrow-format/lib/arrow-format/file-reader.rb index 29c7f5edd49..bf50bfd1cd3 100644 --- a/ruby/red-arrow-format/lib/arrow-format/file-reader.rb +++ b/ruby/red-arrow-format/lib/arrow-format/file-reader.rb @@ -15,57 +15,27 @@ # specific language governing permissions and limitations # under the License. -require_relative "array" -require_relative "error" -require_relative "field" -require_relative "record-batch" -require_relative "schema" -require_relative "type" +require_relative "streaming-reader" -require_relative "org/apache/arrow/flatbuf/binary" -require_relative "org/apache/arrow/flatbuf/bool" -require_relative "org/apache/arrow/flatbuf/date" -require_relative "org/apache/arrow/flatbuf/date_unit" -require_relative "org/apache/arrow/flatbuf/duration" -require_relative "org/apache/arrow/flatbuf/fixed_size_binary" -require_relative "org/apache/arrow/flatbuf/floating_point" +require_relative "org/apache/arrow/flatbuf/block" require_relative "org/apache/arrow/flatbuf/footer" -require_relative "org/apache/arrow/flatbuf/int" -require_relative "org/apache/arrow/flatbuf/interval" -require_relative "org/apache/arrow/flatbuf/interval_unit" -require_relative "org/apache/arrow/flatbuf/large_binary" -require_relative "org/apache/arrow/flatbuf/large_list" -require_relative "org/apache/arrow/flatbuf/large_utf8" -require_relative "org/apache/arrow/flatbuf/list" -require_relative "org/apache/arrow/flatbuf/map" -require_relative "org/apache/arrow/flatbuf/message" -require_relative "org/apache/arrow/flatbuf/null" -require_relative "org/apache/arrow/flatbuf/precision" -require_relative "org/apache/arrow/flatbuf/schema" -require_relative "org/apache/arrow/flatbuf/struct_" -require_relative "org/apache/arrow/flatbuf/time" -require_relative "org/apache/arrow/flatbuf/timestamp" -require_relative "org/apache/arrow/flatbuf/time_unit" -require_relative "org/apache/arrow/flatbuf/union" -require_relative "org/apache/arrow/flatbuf/union_mode" -require_relative "org/apache/arrow/flatbuf/utf8" module ArrowFormat class FileReader include Enumerable + include Readable - MAGIC = "ARROW1".b + MAGIC = "ARROW1".b.freeze MAGIC_BUFFER = IO::Buffer.for(MAGIC) START_MARKER_SIZE = MAGIC_BUFFER.size END_MARKER_SIZE = MAGIC_BUFFER.size - CONTINUATION = "\xFF\xFF\xFF\xFF".b - CONTINUATION_BUFFER = IO::Buffer.for(CONTINUATION) # # STREAMING_FORMAT_START_OFFSET = 8 - INT32_SIZE = 4 - FOOTER_SIZE_SIZE = INT32_SIZE - METADATA_SIZE_SIZE = INT32_SIZE + CONTINUATION_BUFFER = + IO::Buffer.for(MessagePullReader::CONTINUATION_STRING) + FOOTER_SIZE_FORMAT = :s32 + FOOTER_SIZE_SIZE = IO::Buffer.size_of(FOOTER_SIZE_FORMAT) def initialize(input) case input @@ -79,45 +49,75 @@ def initialize(input) validate @footer = read_footer + @record_batches = @footer.record_batches + @schema = read_schema(@footer.schema) end - def each - offset = STREAMING_FORMAT_START_OFFSET - schema = nil - continuation_size = CONTINUATION_BUFFER.size - # streaming format - loop do - continuation = @buffer.slice(offset, continuation_size) - unless continuation == CONTINUATION_BUFFER - raise ReadError.new(@buffer, "No valid continuation") - end - offset += continuation_size + def n_record_batches + @record_batches.size + end - metadata_size = @buffer.get_value(:u32, offset) - offset += METADATA_SIZE_SIZE - break if metadata_size.zero? + def read(i) + block = @record_batches[i] - metadata_data = @buffer.slice(offset, metadata_size) - offset += metadata_size - metadata = Org::Apache::Arrow::Flatbuf::Message.new(metadata_data) + offset = block.offset - body = @buffer.slice(offset, metadata.body_length) - header = metadata.header - case header - when Org::Apache::Arrow::Flatbuf::Schema - schema = read_schema(header) - when Org::Apache::Arrow::Flatbuf::RecordBatch - n_rows = header.length - columns = [] - nodes = header.nodes - buffers = header.buffers - schema.fields.each do |field| - columns << read_column(field, nodes, buffers, body) - end - yield(RecordBatch.new(schema, n_rows, columns)) - end + # If we can report property error information, we can use + # MessagePullReader here. + # + # message_pull_reader = MessagePullReader.new do |message, body| + # return read_record_batch(message.header, @schema, body) + # end + # chunk = @buffer.slice(offset, + # MessagePullReader::CONTINUATION_SIZE + + # MessagePullReader::METADATA_LENGTH_SIZE + + # block.meta_data_length + + # block.body_length) + # message_pull_reader.consume(chunk) - offset += metadata.body_length + continuation_size = CONTINUATION_BUFFER.size + continuation = @buffer.slice(offset, continuation_size) + unless continuation == CONTINUATION_BUFFER + raise FileReadError.new(@buffer, + "Invalid continuation: #{i}: " + + continuation.inspect) + end + offset += continuation_size + + metadata_length_type = MessagePullReader::METADATA_LENGTH_TYPE + metadata_length_size = MessagePullReader::METADATA_LENGTH_SIZE + metadata_length = @buffer.get_value(metadata_length_type, offset) + expected_metadata_length = + block.meta_data_length - + continuation_size - + metadata_length_size + unless metadata_length == expected_metadata_length + raise FileReadError.new(@buffer, + "Invalid metadata length #{i}: " + + "expected:#{expected_metadata_length} " + + "actual:#{metadata_length}") + end + offset += metadata_length_size + + metadata = @buffer.slice(offset, metadata_length) + fb_message = Org::Apache::Arrow::Flatbuf::Message.new(metadata) + fb_header = fb_message.header + unless fb_header.is_a?(Org::Apache::Arrow::Flatbuf::RecordBatch) + raise FileReadError.new(@buffer, + "Not a record batch message: #{i}: " + + fb_header.class.name) + end + offset += metadata_length + + body = @buffer.slice(offset, block.body_length) + read_record_batch(fb_header, @schema, body) + end + + def each + return to_enum(__method__) {n_record_batches} unless block_given? + + @record_batches.size.times do |i| + yield(read(i)) end end @@ -127,204 +127,28 @@ def validate FOOTER_SIZE_SIZE + END_MARKER_SIZE if @buffer.size < minimum_size - raise ReadError.new(@buffer, - "Input must be larger than or equal to " + - "#{minimum_size}: #{@buffer.size}") + raise FileReadError.new(@buffer, + "Input must be larger than or equal to " + + "#{minimum_size}: #{@buffer.size}") end start_marker = @buffer.slice(0, START_MARKER_SIZE) if start_marker != MAGIC_BUFFER - raise ReadError.new(@buffer, "No start marker") + raise FileReadError.new(@buffer, "No start marker") end - end_marker = @buffer.slice(@buffer.size - END_MARKER_SIZE, END_MARKER_SIZE) + end_marker = @buffer.slice(@buffer.size - END_MARKER_SIZE, + END_MARKER_SIZE) if end_marker != MAGIC_BUFFER - raise ReadError.new(@buffer, "No end marker") + raise FileReadError.new(@buffer, "No end marker") end end def read_footer footer_size_offset = @buffer.size - END_MARKER_SIZE - FOOTER_SIZE_SIZE - footer_size = @buffer.get_value(:u32, footer_size_offset) - footer_data = @buffer.slice(footer_size_offset - footer_size, footer_size) + footer_size = @buffer.get_value(FOOTER_SIZE_FORMAT, footer_size_offset) + footer_data = @buffer.slice(footer_size_offset - footer_size, + footer_size) Org::Apache::Arrow::Flatbuf::Footer.new(footer_data) end - - def read_field(fb_field) - fb_type = fb_field.type - case fb_type - when Org::Apache::Arrow::Flatbuf::Null - type = NullType.singleton - when Org::Apache::Arrow::Flatbuf::Bool - type = BooleanType.singleton - when Org::Apache::Arrow::Flatbuf::Int - case fb_type.bit_width - when 8 - if fb_type.signed? - type = Int8Type.singleton - else - type = UInt8Type.singleton - end - when 16 - if fb_type.signed? - type = Int16Type.singleton - else - type = UInt16Type.singleton - end - when 32 - if fb_type.signed? - type = Int32Type.singleton - else - type = UInt32Type.singleton - end - when 64 - if fb_type.signed? - type = Int64Type.singleton - else - type = UInt64Type.singleton - end - end - when Org::Apache::Arrow::Flatbuf::FloatingPoint - case fb_type.precision - when Org::Apache::Arrow::Flatbuf::Precision::SINGLE - type = Float32Type.singleton - when Org::Apache::Arrow::Flatbuf::Precision::DOUBLE - type = Float64Type.singleton - end - when Org::Apache::Arrow::Flatbuf::Date - case fb_type.unit - when Org::Apache::Arrow::Flatbuf::DateUnit::DAY - type = Date32Type.singleton - when Org::Apache::Arrow::Flatbuf::DateUnit::MILLISECOND - type = Date64Type.singleton - end - when Org::Apache::Arrow::Flatbuf::Time - case fb_type.bit_width - when 32 - case fb_type.unit - when Org::Apache::Arrow::Flatbuf::TimeUnit::SECOND - type = Time32Type.new(:second) - when Org::Apache::Arrow::Flatbuf::TimeUnit::MILLISECOND - type = Time32Type.new(:millisecond) - end - when 64 - case fb_type.unit - when Org::Apache::Arrow::Flatbuf::TimeUnit::MICROSECOND - type = Time64Type.new(:microsecond) - when Org::Apache::Arrow::Flatbuf::TimeUnit::NANOSECOND - type = Time64Type.new(:nanosecond) - end - end - when Org::Apache::Arrow::Flatbuf::Timestamp - unit = fb_type.unit.name.downcase.to_sym - type = TimestampType.new(unit, fb_type.timezone) - when Org::Apache::Arrow::Flatbuf::Interval - case fb_type.unit - when Org::Apache::Arrow::Flatbuf::IntervalUnit::YEAR_MONTH - type = YearMonthIntervalType.new - when Org::Apache::Arrow::Flatbuf::IntervalUnit::DAY_TIME - type = DayTimeIntervalType.new - when Org::Apache::Arrow::Flatbuf::IntervalUnit::MONTH_DAY_NANO - type = MonthDayNanoIntervalType.new - end - when Org::Apache::Arrow::Flatbuf::Duration - unit = fb_type.unit.name.downcase.to_sym - type = DurationType.new(unit) - when Org::Apache::Arrow::Flatbuf::List - type = ListType.new(read_field(fb_field.children[0])) - when Org::Apache::Arrow::Flatbuf::LargeList - type = LargeListType.new(read_field(fb_field.children[0])) - when Org::Apache::Arrow::Flatbuf::Struct - children = fb_field.children.collect {|child| read_field(child)} - type = StructType.new(children) - when Org::Apache::Arrow::Flatbuf::Union - children = fb_field.children.collect {|child| read_field(child)} - type_ids = fb_type.type_ids - case fb_type.mode - when Org::Apache::Arrow::Flatbuf::UnionMode::DENSE - type = DenseUnionType.new(children, type_ids) - when Org::Apache::Arrow::Flatbuf::UnionMode::SPARSE - type = SparseUnionType.new(children, type_ids) - end - when Org::Apache::Arrow::Flatbuf::Map - type = MapType.new(read_field(fb_field.children[0])) - when Org::Apache::Arrow::Flatbuf::Binary - type = BinaryType.singleton - when Org::Apache::Arrow::Flatbuf::LargeBinary - type = LargeBinaryType.singleton - when Org::Apache::Arrow::Flatbuf::Utf8 - type = UTF8Type.singleton - when Org::Apache::Arrow::Flatbuf::LargeUtf8 - type = LargeUTF8Type.singleton - when Org::Apache::Arrow::Flatbuf::FixedSizeBinary - type = FixedSizeBinaryType.new(fb_type.byte_width) - end - Field.new(fb_field.name, type, fb_field.nullable?) - end - - def read_schema(fb_schema) - fields = fb_schema.fields.collect do |fb_field| - read_field(fb_field) - end - Schema.new(fields) - end - - def read_column(field, nodes, buffers, body) - node = nodes.shift - length = node.length - - return field.type.build_array(length) if field.type.is_a?(NullType) - - validity_buffer = buffers.shift - if validity_buffer.length.zero? - validity = nil - else - validity = body.slice(validity_buffer.offset, validity_buffer.length) - end - - case field.type - when BooleanType, - NumberType, - TemporalType - values_buffer = buffers.shift - values = body.slice(values_buffer.offset, values_buffer.length) - field.type.build_array(length, validity, values) - when VariableSizeBinaryType - offsets_buffer = buffers.shift - values_buffer = buffers.shift - offsets = body.slice(offsets_buffer.offset, offsets_buffer.length) - values = body.slice(values_buffer.offset, values_buffer.length) - field.type.build_array(length, validity, offsets, values) - when FixedSizeBinaryType - values_buffer = buffers.shift - values = body.slice(values_buffer.offset, values_buffer.length) - field.type.build_array(length, validity, values) - when VariableSizeListType - offsets_buffer = buffers.shift - offsets = body.slice(offsets_buffer.offset, offsets_buffer.length) - child = read_column(field.type.child, nodes, buffers, body) - field.type.build_array(length, validity, offsets, child) - when StructType - children = field.type.children.collect do |child| - read_column(child, nodes, buffers, body) - end - field.type.build_array(length, validity, children) - when DenseUnionType - # dense union type doesn't have validity. - types = validity - offsets_buffer = buffers.shift - offsets = body.slice(offsets_buffer.offset, offsets_buffer.length) - children = field.type.children.collect do |child| - read_column(child, nodes, buffers, body) - end - field.type.build_array(length, types, offsets, children) - when SparseUnionType - # sparse union type doesn't have validity. - types = validity - children = field.type.children.collect do |child| - read_column(child, nodes, buffers, body) - end - field.type.build_array(length, types, children) - end - end end end diff --git a/ruby/red-arrow-format/lib/arrow-format/readable.rb b/ruby/red-arrow-format/lib/arrow-format/readable.rb new file mode 100644 index 00000000000..2d64d5387ff --- /dev/null +++ b/ruby/red-arrow-format/lib/arrow-format/readable.rb @@ -0,0 +1,242 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +require_relative "array" +require_relative "field" +require_relative "record-batch" +require_relative "schema" +require_relative "type" + +require_relative "org/apache/arrow/flatbuf/binary" +require_relative "org/apache/arrow/flatbuf/bool" +require_relative "org/apache/arrow/flatbuf/date" +require_relative "org/apache/arrow/flatbuf/date_unit" +require_relative "org/apache/arrow/flatbuf/duration" +require_relative "org/apache/arrow/flatbuf/fixed_size_binary" +require_relative "org/apache/arrow/flatbuf/floating_point" +require_relative "org/apache/arrow/flatbuf/int" +require_relative "org/apache/arrow/flatbuf/interval" +require_relative "org/apache/arrow/flatbuf/interval_unit" +require_relative "org/apache/arrow/flatbuf/large_binary" +require_relative "org/apache/arrow/flatbuf/large_list" +require_relative "org/apache/arrow/flatbuf/large_utf8" +require_relative "org/apache/arrow/flatbuf/list" +require_relative "org/apache/arrow/flatbuf/map" +require_relative "org/apache/arrow/flatbuf/message" +require_relative "org/apache/arrow/flatbuf/null" +require_relative "org/apache/arrow/flatbuf/precision" +require_relative "org/apache/arrow/flatbuf/schema" +require_relative "org/apache/arrow/flatbuf/struct_" +require_relative "org/apache/arrow/flatbuf/time" +require_relative "org/apache/arrow/flatbuf/timestamp" +require_relative "org/apache/arrow/flatbuf/time_unit" +require_relative "org/apache/arrow/flatbuf/union" +require_relative "org/apache/arrow/flatbuf/union_mode" +require_relative "org/apache/arrow/flatbuf/utf8" + +module ArrowFormat + module Readable + private + def read_schema(fb_schema) + fields = fb_schema.fields.collect do |fb_field| + read_field(fb_field) + end + Schema.new(fields) + end + + def read_field(fb_field) + fb_type = fb_field.type + case fb_type + when Org::Apache::Arrow::Flatbuf::Null + type = NullType.singleton + when Org::Apache::Arrow::Flatbuf::Bool + type = BooleanType.singleton + when Org::Apache::Arrow::Flatbuf::Int + case fb_type.bit_width + when 8 + if fb_type.signed? + type = Int8Type.singleton + else + type = UInt8Type.singleton + end + when 16 + if fb_type.signed? + type = Int16Type.singleton + else + type = UInt16Type.singleton + end + when 32 + if fb_type.signed? + type = Int32Type.singleton + else + type = UInt32Type.singleton + end + when 64 + if fb_type.signed? + type = Int64Type.singleton + else + type = UInt64Type.singleton + end + end + when Org::Apache::Arrow::Flatbuf::FloatingPoint + case fb_type.precision + when Org::Apache::Arrow::Flatbuf::Precision::SINGLE + type = Float32Type.singleton + when Org::Apache::Arrow::Flatbuf::Precision::DOUBLE + type = Float64Type.singleton + end + when Org::Apache::Arrow::Flatbuf::Date + case fb_type.unit + when Org::Apache::Arrow::Flatbuf::DateUnit::DAY + type = Date32Type.singleton + when Org::Apache::Arrow::Flatbuf::DateUnit::MILLISECOND + type = Date64Type.singleton + end + when Org::Apache::Arrow::Flatbuf::Time + case fb_type.bit_width + when 32 + case fb_type.unit + when Org::Apache::Arrow::Flatbuf::TimeUnit::SECOND + type = Time32Type.new(:second) + when Org::Apache::Arrow::Flatbuf::TimeUnit::MILLISECOND + type = Time32Type.new(:millisecond) + end + when 64 + case fb_type.unit + when Org::Apache::Arrow::Flatbuf::TimeUnit::MICROSECOND + type = Time64Type.new(:microsecond) + when Org::Apache::Arrow::Flatbuf::TimeUnit::NANOSECOND + type = Time64Type.new(:nanosecond) + end + end + when Org::Apache::Arrow::Flatbuf::Timestamp + unit = fb_type.unit.name.downcase.to_sym + type = TimestampType.new(unit, fb_type.timezone) + when Org::Apache::Arrow::Flatbuf::Interval + case fb_type.unit + when Org::Apache::Arrow::Flatbuf::IntervalUnit::YEAR_MONTH + type = YearMonthIntervalType.new + when Org::Apache::Arrow::Flatbuf::IntervalUnit::DAY_TIME + type = DayTimeIntervalType.new + when Org::Apache::Arrow::Flatbuf::IntervalUnit::MONTH_DAY_NANO + type = MonthDayNanoIntervalType.new + end + when Org::Apache::Arrow::Flatbuf::Duration + unit = fb_type.unit.name.downcase.to_sym + type = DurationType.new(unit) + when Org::Apache::Arrow::Flatbuf::List + type = ListType.new(read_field(fb_field.children[0])) + when Org::Apache::Arrow::Flatbuf::LargeList + type = LargeListType.new(read_field(fb_field.children[0])) + when Org::Apache::Arrow::Flatbuf::Struct + children = fb_field.children.collect {|child| read_field(child)} + type = StructType.new(children) + when Org::Apache::Arrow::Flatbuf::Union + children = fb_field.children.collect {|child| read_field(child)} + type_ids = fb_type.type_ids + case fb_type.mode + when Org::Apache::Arrow::Flatbuf::UnionMode::DENSE + type = DenseUnionType.new(children, type_ids) + when Org::Apache::Arrow::Flatbuf::UnionMode::SPARSE + type = SparseUnionType.new(children, type_ids) + end + when Org::Apache::Arrow::Flatbuf::Map + type = MapType.new(read_field(fb_field.children[0])) + when Org::Apache::Arrow::Flatbuf::Binary + type = BinaryType.singleton + when Org::Apache::Arrow::Flatbuf::LargeBinary + type = LargeBinaryType.singleton + when Org::Apache::Arrow::Flatbuf::Utf8 + type = UTF8Type.singleton + when Org::Apache::Arrow::Flatbuf::LargeUtf8 + type = LargeUTF8Type.singleton + when Org::Apache::Arrow::Flatbuf::FixedSizeBinary + type = FixedSizeBinaryType.new(fb_type.byte_width) + end + Field.new(fb_field.name, type, fb_field.nullable?) + end + + def read_record_batch(fb_record_batch, schema, body) + n_rows = fb_record_batch.length + nodes = fb_record_batch.nodes + buffers = fb_record_batch.buffers + columns = @schema.fields.collect do |field| + read_column(field, nodes, buffers, body) + end + RecordBatch.new(schema, n_rows, columns) + end + + def read_column(field, nodes, buffers, body) + node = nodes.shift + length = node.length + + return field.type.build_array(length) if field.type.is_a?(NullType) + + validity_buffer = buffers.shift + if validity_buffer.length.zero? + validity = nil + else + validity = body.slice(validity_buffer.offset, validity_buffer.length) + end + + case field.type + when BooleanType, + NumberType, + TemporalType + values_buffer = buffers.shift + values = body.slice(values_buffer.offset, values_buffer.length) + field.type.build_array(length, validity, values) + when VariableSizeBinaryType + offsets_buffer = buffers.shift + values_buffer = buffers.shift + offsets = body.slice(offsets_buffer.offset, offsets_buffer.length) + values = body.slice(values_buffer.offset, values_buffer.length) + field.type.build_array(length, validity, offsets, values) + when FixedSizeBinaryType + values_buffer = buffers.shift + values = body.slice(values_buffer.offset, values_buffer.length) + field.type.build_array(length, validity, values) + when VariableSizeListType + offsets_buffer = buffers.shift + offsets = body.slice(offsets_buffer.offset, offsets_buffer.length) + child = read_column(field.type.child, nodes, buffers, body) + field.type.build_array(length, validity, offsets, child) + when StructType + children = field.type.children.collect do |child| + read_column(child, nodes, buffers, body) + end + field.type.build_array(length, validity, children) + when DenseUnionType + # dense union type doesn't have validity. + types = validity + offsets_buffer = buffers.shift + offsets = body.slice(offsets_buffer.offset, offsets_buffer.length) + children = field.type.children.collect do |child| + read_column(child, nodes, buffers, body) + end + field.type.build_array(length, types, offsets, children) + when SparseUnionType + # sparse union type doesn't have validity. + types = validity + children = field.type.children.collect do |child| + read_column(child, nodes, buffers, body) + end + field.type.build_array(length, types, children) + end + end + end +end diff --git a/ruby/red-arrow-format/lib/arrow-format/streaming-pull-reader.rb b/ruby/red-arrow-format/lib/arrow-format/streaming-pull-reader.rb new file mode 100644 index 00000000000..ae231fccbc6 --- /dev/null +++ b/ruby/red-arrow-format/lib/arrow-format/streaming-pull-reader.rb @@ -0,0 +1,200 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +require_relative "array" +require_relative "error" +require_relative "field" +require_relative "readable" +require_relative "record-batch" +require_relative "schema" +require_relative "type" + +module ArrowFormat + class MessagePullReader + CONTINUATION_TYPE = :s32 + CONTINUATION_SIZE = IO::Buffer.size_of(CONTINUATION_TYPE) + CONTINUATION_STRING = "\xFF\xFF\xFF\xFF".b.freeze + CONTINUATION_INT32 = -1 + METADATA_LENGTH_TYPE = :s32 + METADATA_LENGTH_SIZE = IO::Buffer.size_of(METADATA_LENGTH_TYPE) + + def initialize(&on_read) + @on_read = on_read + @buffer = IO::Buffer.new(0) + @metadata_length = nil + @body_length = nil + @state = :initial + end + + def next_required_size + case @state + when :initial + CONTINUATION_SIZE + when :metadata_length + METADATA_LENGTH_SIZE + when :metadata + @metadata_length + when :body + @body_length + when :eos + 0 + end + end + + def eos? + @state == :eos + end + + def consume(chunk) + return if eos? + + if @buffer.size.zero? + target = chunk + else + @buffer.resize(@buffer.size + chunk.size) + @buffer.copy(chunk) + target = @buffer + end + + loop do + next_size = next_required_size + break if next_size.zero? + + if target.size < next_size + @buffer.resize(target.size) if @buffer.size < target.size + @buffer.copy(target) + @buffer.resize(target.size) + return + end + + case @state + when :initial + consume_initial(target) + when :metadata_length + consume_metadata_length(target) + when :metadata + consume_metadata(target) + when :body + consume_body(target) + end + break if target.size == next_size + + target = target.slice(next_size) + end + end + + private + def consume_initial(target) + continuation = target.get_value(CONTINUATION_TYPE, 0) + unless continuation == CONTINUATION_INT32 + raise ReadError.new("Invalid continuation token: " + + continuation.inspect) + end + @state = :metadata_length + end + + def consume_metadata_length(target) + length = target.get_value(METADATA_LENGTH_TYPE, 0) + if length < 0 + raise ReadError.new("Negative metadata length: " + + length.inspect) + end + if length == 0 + @state = :eos + else + @metadata_length = length + @state = :metadata + end + end + + def consume_metadata(target) + metadata_buffer = target.slice(0, @metadata_length) + @message = Org::Apache::Arrow::Flatbuf::Message.new(metadata_buffer) + @body_length = @message.body_length + if @body_length < 0 + raise ReadError.new("Negative body length: " + + @body_length.inspect) + end + @state = :body + consume_body if @body_length.zero? + end + + def consume_body(target=nil) + body = target&.slice(0, @body_length) + @on_read.call(@message, body) + @state = :initial + end + end + + class StreamingPullReader + include Readable + + attr_reader :schema + def initialize(&on_read) + @on_read = on_read + @message_pull_reader = MessagePullReader.new do |message, body| + process_message(message, body) + end + @state = :schema + @schema = nil + end + + def next_required_size + @message_pull_reader.next_required_size + end + + def eos? + @message_pull_reader.eos? + end + + def consume(chunk) + @message_pull_reader.consume(chunk) + end + + private + def process_message(message, body) + case @state + when :schema + process_schema_message(message, body) + when :record_batch + process_record_batch_message(message, body) + end + end + + def process_schema_message(message, body) + header = message.header + unless header.is_a?(Org::Apache::Arrow::Flatbuf::Schema) + raise ReadError.new("Not a schema message: " + + header.inspect) + end + + @schema = read_schema(header) + # TODO: initial dictionaries support + @state = :record_batch + end + + def process_record_batch_message(message, body) + header = message.header + unless header.is_a?(Org::Apache::Arrow::Flatbuf::RecordBatch) + raise ReadError.new("Not a record batch message: " + + header.inspect) + end + + @on_read.call(read_record_batch(header, @schema, body)) + end + end +end diff --git a/ruby/red-arrow-format/lib/arrow-format/streaming-reader.rb b/ruby/red-arrow-format/lib/arrow-format/streaming-reader.rb new file mode 100644 index 00000000000..f11972c67a2 --- /dev/null +++ b/ruby/red-arrow-format/lib/arrow-format/streaming-reader.rb @@ -0,0 +1,50 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +require_relative "streaming-pull-reader" + +module ArrowFormat + class StreamingReader + include Enumerable + + attr_reader :schema + def initialize(input) + @input = input + @schema = nil + end + + def each + return to_enum(__method__) unless block_given? + + reader = StreamingPullReader.new do |record_batch| + @schema ||= reader.schema + yield(record_batch) + end + + buffer = "".b + loop do + next_size = reader.next_required_size + break if next_size.zero? + + next_chunk = @input.read(next_size, buffer) + break if next_chunk.nil? + + reader.consume(IO::Buffer.for(next_chunk)) + end + end + end +end diff --git a/ruby/red-arrow-format/test/test-file-reader.rb b/ruby/red-arrow-format/test/test-file-reader.rb deleted file mode 100644 index 6198f0cb96d..00000000000 --- a/ruby/red-arrow-format/test/test-file-reader.rb +++ /dev/null @@ -1,790 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -class TestFileReader < Test::Unit::TestCase - def setup - Dir.mktmpdir do |tmp_dir| - table = Arrow::Table.new(value: build_array) - @path = File.join(tmp_dir, "data.arrow") - table.save(@path) - File.open(@path, "rb") do |input| - @reader = ArrowFormat::FileReader.new(input) - yield - @reader = nil - end - GC.start - end - end - - def read - @reader.to_a.collect do |record_batch| - record_batch.to_h.tap do |hash| - hash.each do |key, value| - hash[key] = value.to_a - end - end - end - end - - def type - @type ||= @reader.first.schema.fields[0].type - end - - sub_test_case("Null") do - def build_array - Arrow::NullArray.new(3) - end - - def test_read - assert_equal([{"value" => [nil, nil, nil]}], - read) - end - end - - sub_test_case("Boolean") do - def build_array - Arrow::BooleanArray.new([true, nil, false]) - end - - def test_read - assert_equal([{"value" => [true, nil, false]}], - read) - end - end - - sub_test_case("Int8") do - def build_array - Arrow::Int8Array.new([-128, nil, 127]) - end - - def test_read - assert_equal([{"value" => [-128, nil, 127]}], - read) - end - end - - sub_test_case("UInt8") do - def build_array - Arrow::UInt8Array.new([0, nil, 255]) - end - - def test_read - assert_equal([{"value" => [0, nil, 255]}], - read) - end - end - - sub_test_case("Int16") do - def build_array - Arrow::Int16Array.new([-32768, nil, 32767]) - end - - def test_read - assert_equal([{"value" => [-32768, nil, 32767]}], - read) - end - end - - sub_test_case("UInt16") do - def build_array - Arrow::UInt16Array.new([0, nil, 65535]) - end - - def test_read - assert_equal([{"value" => [0, nil, 65535]}], - read) - end - end - - sub_test_case("Int32") do - def build_array - Arrow::Int32Array.new([-2147483648, nil, 2147483647]) - end - - def test_read - assert_equal([{"value" => [-2147483648, nil, 2147483647]}], - read) - end - end - - sub_test_case("UInt32") do - def build_array - Arrow::UInt32Array.new([0, nil, 4294967295]) - end - - def test_read - assert_equal([{"value" => [0, nil, 4294967295]}], - read) - end - end - - sub_test_case("Int64") do - def build_array - Arrow::Int64Array.new([ - -9223372036854775808, - nil, - 9223372036854775807 - ]) - end - - def test_read - assert_equal([ - { - "value" => [ - -9223372036854775808, - nil, - 9223372036854775807 - ] - } - ], - read) - end - end - - sub_test_case("UInt64") do - def build_array - Arrow::UInt64Array.new([0, nil, 18446744073709551615]) - end - - def test_read - assert_equal([{"value" => [0, nil, 18446744073709551615]}], - read) - end - end - - sub_test_case("Float32") do - def build_array - Arrow::FloatArray.new([-0.5, nil, 0.5]) - end - - def test_read - assert_equal([{"value" => [-0.5, nil, 0.5]}], - read) - end - end - - sub_test_case("Float64") do - def build_array - Arrow::DoubleArray.new([-0.5, nil, 0.5]) - end - - def test_read - assert_equal([{"value" => [-0.5, nil, 0.5]}], - read) - end - end - - sub_test_case("Date32") do - def setup(&block) - @date_2017_08_28 = 17406 - @date_2025_12_09 = 20431 - super(&block) - end - - def build_array - Arrow::Date32Array.new([@date_2017_08_28, nil, @date_2025_12_09]) - end - - def test_read - assert_equal([{"value" => [@date_2017_08_28, nil, @date_2025_12_09]}], - read) - end - end - - sub_test_case("Date64") do - def setup(&block) - @date_2017_08_28_00_00_00 = 1503878400000 - @date_2025_12_09_00_00_00 = 1765324800000 - super(&block) - end - - def build_array - Arrow::Date64Array.new([ - @date_2017_08_28_00_00_00, - nil, - @date_2025_12_09_00_00_00, - ]) - end - - def test_read - assert_equal([ - { - "value" => [ - @date_2017_08_28_00_00_00, - nil, - @date_2025_12_09_00_00_00, - ], - }, - ], - read) - end - end - - sub_test_case("Time32(:second)") do - def setup(&block) - @time_00_00_10 = 10 - @time_00_01_10 = 60 + 10 - super(&block) - end - - def build_array - Arrow::Time32Array.new(:second, [@time_00_00_10, nil, @time_00_01_10]) - end - - def test_read - assert_equal([{"value" => [@time_00_00_10, nil, @time_00_01_10]}], - read) - end - - def test_type - assert_equal(:second, type.unit) - end - end - - sub_test_case("Time32(:millisecond)") do - def setup(&block) - @time_00_00_10_000 = 10 * 1000 - @time_00_01_10_000 = (60 + 10) * 1000 - super(&block) - end - - def build_array - Arrow::Time32Array.new(:milli, - [@time_00_00_10_000, nil, @time_00_01_10_000]) - end - - def test_read - assert_equal([{"value" => [@time_00_00_10_000, nil, @time_00_01_10_000]}], - read) - end - - def test_type - assert_equal(:millisecond, type.unit) - end - end - - sub_test_case("Time64(:microsecond)") do - def setup(&block) - @time_00_00_10_000_000 = 10 * 1_000_000 - @time_00_01_10_000_000 = (60 + 10) * 1_000_000 - super(&block) - end - - def build_array - Arrow::Time64Array.new(:micro, - [ - @time_00_00_10_000_000, - nil, - @time_00_01_10_000_000, - ]) - end - - def test_read - assert_equal([ - { - "value" => [ - @time_00_00_10_000_000, - nil, - @time_00_01_10_000_000, - ], - }, - ], - read) - end - - def test_type - assert_equal(:microsecond, type.unit) - end - end - - sub_test_case("Time64(:nanosecond)") do - def setup(&block) - @time_00_00_10_000_000_000 = 10 * 1_000_000_000 - @time_00_01_10_000_000_000 = (60 + 10) * 1_000_000_000 - super(&block) - end - - def build_array - Arrow::Time64Array.new(:nano, - [ - @time_00_00_10_000_000_000, - nil, - @time_00_01_10_000_000_000, - ]) - end - - def test_read - assert_equal([ - { - "value" => [ - @time_00_00_10_000_000_000, - nil, - @time_00_01_10_000_000_000, - ], - }, - ], - read) - end - - def test_type - assert_equal(:nanosecond, type.unit) - end - end - - sub_test_case("Timestamp(:second)") do - def setup(&block) - @timestamp_2019_11_18_00_09_11 = 1574003351 - @timestamp_2025_12_16_05_33_58 = 1765863238 - super(&block) - end - - def build_array - Arrow::TimestampArray.new(:second, - [ - @timestamp_2019_11_18_00_09_11, - nil, - @timestamp_2025_12_16_05_33_58, - ]) - end - - def test_read - assert_equal([ - { - "value" => [ - @timestamp_2019_11_18_00_09_11, - nil, - @timestamp_2025_12_16_05_33_58, - ], - }, - ], - read) - end - end - - sub_test_case("Timestamp(:millisecond)") do - def setup(&block) - @timestamp_2019_11_18_00_09_11 = 1574003351 * 1_000 - @timestamp_2025_12_16_05_33_58 = 1765863238 * 1_000 - super(&block) - end - - def build_array - Arrow::TimestampArray.new(:milli, - [ - @timestamp_2019_11_18_00_09_11, - nil, - @timestamp_2025_12_16_05_33_58, - ]) - end - - def test_read - assert_equal([ - { - "value" => [ - @timestamp_2019_11_18_00_09_11, - nil, - @timestamp_2025_12_16_05_33_58, - ], - }, - ], - read) - end - end - - sub_test_case("Timestamp(:microsecond)") do - def setup(&block) - @timestamp_2019_11_18_00_09_11 = 1574003351 * 1_000_000 - @timestamp_2025_12_16_05_33_58 = 1765863238 * 1_000_000 - super(&block) - end - - def build_array - Arrow::TimestampArray.new(:micro, - [ - @timestamp_2019_11_18_00_09_11, - nil, - @timestamp_2025_12_16_05_33_58, - ]) - end - - def test_read - assert_equal([ - { - "value" => [ - @timestamp_2019_11_18_00_09_11, - nil, - @timestamp_2025_12_16_05_33_58, - ], - }, - ], - read) - end - end - - sub_test_case("Timestamp(:nanosecond)") do - def setup(&block) - @timestamp_2019_11_18_00_09_11 = 1574003351 * 1_000_000_000 - @timestamp_2025_12_16_05_33_58 = 1765863238 * 1_000_000_000 - super(&block) - end - - def build_array - Arrow::TimestampArray.new(:nano, - [ - @timestamp_2019_11_18_00_09_11, - nil, - @timestamp_2025_12_16_05_33_58, - ]) - end - - def test_read - assert_equal([ - { - "value" => [ - @timestamp_2019_11_18_00_09_11, - nil, - @timestamp_2025_12_16_05_33_58, - ], - }, - ], - read) - end - end - - sub_test_case("Timestamp(timezone)") do - def setup(&block) - @timezone = "UTC" - @timestamp_2019_11_18_00_09_11 = 1574003351 - @timestamp_2025_12_16_05_33_58 = 1765863238 - super(&block) - end - - def build_array - data_type = Arrow::TimestampDataType.new(:second, @timezone) - Arrow::TimestampArray.new(data_type, - [ - @timestamp_2019_11_18_00_09_11, - nil, - @timestamp_2025_12_16_05_33_58, - ]) - end - - def test_type - assert_equal([:second, @timezone], - [type.unit, type.timezone]) - end - end - - sub_test_case("YearMonthInterval") do - def build_array - Arrow::MonthIntervalArray.new([0, nil, 100]) - end - - def test_read - assert_equal([{"value" => [0, nil, 100]}], - read) - end - end - - sub_test_case("DayTimeInterval") do - def build_array - Arrow::DayTimeIntervalArray.new([ - {day: 1, millisecond: 100}, - nil, - {day: 3, millisecond: 300}, - ]) - end - - def test_read - assert_equal([ - { - "value" => [ - [1, 100], - nil, - [3, 300], - ], - }, - ], - read) - end - end - - sub_test_case("MonthDayNanoInterval") do - def build_array - Arrow::MonthDayNanoIntervalArray.new([ - { - month: 1, - day: 1, - nanosecond: 100, - }, - nil, - { - month: 3, - day: 3, - nanosecond: 300, - }, - ]) - end - - def test_read - assert_equal([ - { - "value" => [ - [1, 1, 100], - nil, - [3, 3, 300], - ], - }, - ], - read) - end - end - - sub_test_case("Duration(:second)") do - def build_array - Arrow::DurationArray.new(:second, [0, nil, 100]) - end - - def test_read - assert_equal([{"value" => [0, nil, 100]}], - read) - end - - def test_type - assert_equal(:second, type.unit) - end - end - - sub_test_case("Duration(:millisecond)") do - def build_array - Arrow::DurationArray.new(:milli, [0, nil, 100_000]) - end - - def test_read - assert_equal([{"value" => [0, nil, 100_000]}], - read) - end - - def test_type - assert_equal(:millisecond, type.unit) - end - end - - sub_test_case("Duration(:microsecond)") do - def build_array - Arrow::DurationArray.new(:micro, [0, nil, 100_000_000]) - end - - def test_read - assert_equal([{"value" => [0, nil, 100_000_000]}], - read) - end - - def test_type - assert_equal(:microsecond, type.unit) - end - end - - sub_test_case("Duration(:nanosecond)") do - def build_array - Arrow::DurationArray.new(:nano, [0, nil, 100_000_000_000]) - end - - def test_read - assert_equal([{"value" => [0, nil, 100_000_000_000]}], - read) - end - - def test_type - assert_equal(:nanosecond, type.unit) - end - end - - sub_test_case("Binary") do - def build_array - Arrow::BinaryArray.new(["Hello".b, nil, "World".b]) - end - - def test_read - assert_equal([{"value" => ["Hello".b, nil, "World".b]}], - read) - end - end - - sub_test_case("LargeBinary") do - def build_array - Arrow::LargeBinaryArray.new(["Hello".b, nil, "World".b]) - end - - def test_read - assert_equal([{"value" => ["Hello".b, nil, "World".b]}], - read) - end - end - - sub_test_case("UTF8") do - def build_array - Arrow::StringArray.new(["Hello", nil, "World"]) - end - - def test_read - assert_equal([{"value" => ["Hello", nil, "World"]}], - read) - end - end - - sub_test_case("LargeUTF8") do - def build_array - Arrow::LargeStringArray.new(["Hello", nil, "World"]) - end - - def test_read - assert_equal([{"value" => ["Hello", nil, "World"]}], - read) - end - end - - sub_test_case("FixedSizeBinary") do - def build_array - data_type = Arrow::FixedSizeBinaryDataType.new(4) - Arrow::FixedSizeBinaryArray.new(data_type, ["0124".b, nil, "abcd".b]) - end - - def test_read - assert_equal([{"value" => ["0124".b, nil, "abcd".b]}], - read) - end - end - - sub_test_case("List") do - def build_array - data_type = Arrow::ListDataType.new(name: "count", type: :int8) - Arrow::ListArray.new(data_type, [[-128, 127], nil, [-1, 0, 1]]) - end - - def test_read - assert_equal([{"value" => [[-128, 127], nil, [-1, 0, 1]]}], - read) - end - end - - sub_test_case("LargeList") do - def build_array - data_type = Arrow::LargeListDataType.new(name: "count", type: :int8) - Arrow::LargeListArray.new(data_type, [[-128, 127], nil, [-1, 0, 1]]) - end - - def test_read - assert_equal([{"value" => [[-128, 127], nil, [-1, 0, 1]]}], - read) - end - end - - sub_test_case("Struct") do - def build_array - data_type = Arrow::StructDataType.new(count: :int8, - visible: :boolean) - Arrow::StructArray.new(data_type, [[-128, nil], nil, [nil, true]]) - end - - def test_read - assert_equal([ - { - "value" => [ - [-128, nil], - nil, - [nil, true], - ], - }, - ], - read) - end - end - - sub_test_case("DenseUnion") do - def build_array - fields = [ - Arrow::Field.new("number", :int8), - Arrow::Field.new("text", :string), - ] - type_ids = [11, 13] - data_type = Arrow::DenseUnionDataType.new(fields, type_ids) - types = Arrow::Int8Array.new([11, 13, 11, 13, 13]) - value_offsets = Arrow::Int32Array.new([0, 0, 1, 1, 2]) - children = [ - Arrow::Int8Array.new([1, nil]), - Arrow::StringArray.new(["a", "b", "c"]) - ] - Arrow::DenseUnionArray.new(data_type, - types, - value_offsets, - children) - end - - def test_read - assert_equal([{"value" => [1, "a", nil, "b", "c"]}], - read) - end - end - - sub_test_case("SparseUnion") do - def build_array - fields = [ - Arrow::Field.new("number", :int8), - Arrow::Field.new("text", :string), - ] - type_ids = [11, 13] - data_type = Arrow::SparseUnionDataType.new(fields, type_ids) - types = Arrow::Int8Array.new([11, 13, 11, 13, 11]) - children = [ - Arrow::Int8Array.new([1, nil, nil, nil, 5]), - Arrow::StringArray.new([nil, "b", nil, "d", nil]) - ] - Arrow::SparseUnionArray.new(data_type, types, children) - end - - def test_read - assert_equal([{"value" => [1, "b", nil, "d", 5]}], - read) - end - end - - sub_test_case("Map") do - def build_array - data_type = Arrow::MapDataType.new(:string, :int8) - Arrow::MapArray.new(data_type, - [ - {"a" => -128, "b" => 127}, - nil, - {"c" => nil}, - ]) - end - - def test_read - assert_equal([ - { - "value" => [ - {"a" => -128, "b" => 127}, - nil, - {"c" => nil}, - ], - }, - ], - read) - end - end -end diff --git a/ruby/red-arrow-format/test/test-reader.rb b/ruby/red-arrow-format/test/test-reader.rb new file mode 100644 index 00000000000..8095adfd50f --- /dev/null +++ b/ruby/red-arrow-format/test/test-reader.rb @@ -0,0 +1,872 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +module ReaderTests + class << self + def included(base) + base.class_eval do + sub_test_case("Null") do + def build_array + Arrow::NullArray.new(3) + end + + def test_read + assert_equal([{"value" => [nil, nil, nil]}], + read) + end + end + + sub_test_case("Boolean") do + def build_array + Arrow::BooleanArray.new([true, nil, false]) + end + + def test_read + assert_equal([{"value" => [true, nil, false]}], + read) + end + end + + sub_test_case("Int8") do + def build_array + Arrow::Int8Array.new([-128, nil, 127]) + end + + def test_read + assert_equal([{"value" => [-128, nil, 127]}], + read) + end + end + + sub_test_case("UInt8") do + def build_array + Arrow::UInt8Array.new([0, nil, 255]) + end + + def test_read + assert_equal([{"value" => [0, nil, 255]}], + read) + end + end + + sub_test_case("Int16") do + def build_array + Arrow::Int16Array.new([-32768, nil, 32767]) + end + + def test_read + assert_equal([{"value" => [-32768, nil, 32767]}], + read) + end + end + + sub_test_case("UInt16") do + def build_array + Arrow::UInt16Array.new([0, nil, 65535]) + end + + def test_read + assert_equal([{"value" => [0, nil, 65535]}], + read) + end + end + + sub_test_case("Int32") do + def build_array + Arrow::Int32Array.new([-2147483648, nil, 2147483647]) + end + + def test_read + assert_equal([{"value" => [-2147483648, nil, 2147483647]}], + read) + end + end + + sub_test_case("UInt32") do + def build_array + Arrow::UInt32Array.new([0, nil, 4294967295]) + end + + def test_read + assert_equal([{"value" => [0, nil, 4294967295]}], + read) + end + end + + sub_test_case("Int64") do + def build_array + Arrow::Int64Array.new([ + -9223372036854775808, + nil, + 9223372036854775807 + ]) + end + + def test_read + assert_equal([ + { + "value" => [ + -9223372036854775808, + nil, + 9223372036854775807 + ] + } + ], + read) + end + end + + sub_test_case("UInt64") do + def build_array + Arrow::UInt64Array.new([0, nil, 18446744073709551615]) + end + + def test_read + assert_equal([{"value" => [0, nil, 18446744073709551615]}], + read) + end + end + + sub_test_case("Float32") do + def build_array + Arrow::FloatArray.new([-0.5, nil, 0.5]) + end + + def test_read + assert_equal([{"value" => [-0.5, nil, 0.5]}], + read) + end + end + + sub_test_case("Float64") do + def build_array + Arrow::DoubleArray.new([-0.5, nil, 0.5]) + end + + def test_read + assert_equal([{"value" => [-0.5, nil, 0.5]}], + read) + end + end + + sub_test_case("Date32") do + def setup(&block) + @date_2017_08_28 = 17406 + @date_2025_12_09 = 20431 + super(&block) + end + + def build_array + Arrow::Date32Array.new([@date_2017_08_28, nil, @date_2025_12_09]) + end + + def test_read + assert_equal([ + { + "value" => [ + @date_2017_08_28, + nil, + @date_2025_12_09, + ], + }, + ], + read) + end + end + + sub_test_case("Date64") do + def setup(&block) + @date_2017_08_28_00_00_00 = 1503878400000 + @date_2025_12_09_00_00_00 = 1765324800000 + super(&block) + end + + def build_array + Arrow::Date64Array.new([ + @date_2017_08_28_00_00_00, + nil, + @date_2025_12_09_00_00_00, + ]) + end + + def test_read + assert_equal([ + { + "value" => [ + @date_2017_08_28_00_00_00, + nil, + @date_2025_12_09_00_00_00, + ], + }, + ], + read) + end + end + + sub_test_case("Time32(:second)") do + def setup(&block) + @time_00_00_10 = 10 + @time_00_01_10 = 60 + 10 + super(&block) + end + + def build_array + Arrow::Time32Array.new(:second, [@time_00_00_10, nil, @time_00_01_10]) + end + + def test_read + assert_equal([ + { + "value" => [ + @time_00_00_10, + nil, + @time_00_01_10, + ], + }, + ], + read) + end + + def test_type + assert_equal(:second, type.unit) + end + end + + sub_test_case("Time32(:millisecond)") do + def setup(&block) + @time_00_00_10_000 = 10 * 1000 + @time_00_01_10_000 = (60 + 10) * 1000 + super(&block) + end + + def build_array + Arrow::Time32Array.new(:milli, + [ + @time_00_00_10_000, + nil, + @time_00_01_10_000, + ]) + end + + def test_read + assert_equal([ + { + "value" => [ + @time_00_00_10_000, + nil, + @time_00_01_10_000, + ], + }, + ], + read) + end + + def test_type + assert_equal(:millisecond, type.unit) + end + end + + sub_test_case("Time64(:microsecond)") do + def setup(&block) + @time_00_00_10_000_000 = 10 * 1_000_000 + @time_00_01_10_000_000 = (60 + 10) * 1_000_000 + super(&block) + end + + def build_array + Arrow::Time64Array.new(:micro, + [ + @time_00_00_10_000_000, + nil, + @time_00_01_10_000_000, + ]) + end + + def test_read + assert_equal([ + { + "value" => [ + @time_00_00_10_000_000, + nil, + @time_00_01_10_000_000, + ], + }, + ], + read) + end + + def test_type + assert_equal(:microsecond, type.unit) + end + end + + sub_test_case("Time64(:nanosecond)") do + def setup(&block) + @time_00_00_10_000_000_000 = 10 * 1_000_000_000 + @time_00_01_10_000_000_000 = (60 + 10) * 1_000_000_000 + super(&block) + end + + def build_array + Arrow::Time64Array.new(:nano, + [ + @time_00_00_10_000_000_000, + nil, + @time_00_01_10_000_000_000, + ]) + end + + def test_read + assert_equal([ + { + "value" => [ + @time_00_00_10_000_000_000, + nil, + @time_00_01_10_000_000_000, + ], + }, + ], + read) + end + + def test_type + assert_equal(:nanosecond, type.unit) + end + end + + sub_test_case("Timestamp(:second)") do + def setup(&block) + @timestamp_2019_11_18_00_09_11 = 1574003351 + @timestamp_2025_12_16_05_33_58 = 1765863238 + super(&block) + end + + def build_array + Arrow::TimestampArray.new(:second, + [ + @timestamp_2019_11_18_00_09_11, + nil, + @timestamp_2025_12_16_05_33_58, + ]) + end + + def test_read + assert_equal([ + { + "value" => [ + @timestamp_2019_11_18_00_09_11, + nil, + @timestamp_2025_12_16_05_33_58, + ], + }, + ], + read) + end + end + + sub_test_case("Timestamp(:millisecond)") do + def setup(&block) + @timestamp_2019_11_18_00_09_11 = 1574003351 * 1_000 + @timestamp_2025_12_16_05_33_58 = 1765863238 * 1_000 + super(&block) + end + + def build_array + Arrow::TimestampArray.new(:milli, + [ + @timestamp_2019_11_18_00_09_11, + nil, + @timestamp_2025_12_16_05_33_58, + ]) + end + + def test_read + assert_equal([ + { + "value" => [ + @timestamp_2019_11_18_00_09_11, + nil, + @timestamp_2025_12_16_05_33_58, + ], + }, + ], + read) + end + end + + sub_test_case("Timestamp(:microsecond)") do + def setup(&block) + @timestamp_2019_11_18_00_09_11 = 1574003351 * 1_000_000 + @timestamp_2025_12_16_05_33_58 = 1765863238 * 1_000_000 + super(&block) + end + + def build_array + Arrow::TimestampArray.new(:micro, + [ + @timestamp_2019_11_18_00_09_11, + nil, + @timestamp_2025_12_16_05_33_58, + ]) + end + + def test_read + assert_equal([ + { + "value" => [ + @timestamp_2019_11_18_00_09_11, + nil, + @timestamp_2025_12_16_05_33_58, + ], + }, + ], + read) + end + end + + sub_test_case("Timestamp(:nanosecond)") do + def setup(&block) + @timestamp_2019_11_18_00_09_11 = 1574003351 * 1_000_000_000 + @timestamp_2025_12_16_05_33_58 = 1765863238 * 1_000_000_000 + super(&block) + end + + def build_array + Arrow::TimestampArray.new(:nano, + [ + @timestamp_2019_11_18_00_09_11, + nil, + @timestamp_2025_12_16_05_33_58, + ]) + end + + def test_read + assert_equal([ + { + "value" => [ + @timestamp_2019_11_18_00_09_11, + nil, + @timestamp_2025_12_16_05_33_58, + ], + }, + ], + read) + end + end + + sub_test_case("Timestamp(timezone)") do + def setup(&block) + @timezone = "UTC" + @timestamp_2019_11_18_00_09_11 = 1574003351 + @timestamp_2025_12_16_05_33_58 = 1765863238 + super(&block) + end + + def build_array + data_type = Arrow::TimestampDataType.new(:second, @timezone) + Arrow::TimestampArray.new(data_type, + [ + @timestamp_2019_11_18_00_09_11, + nil, + @timestamp_2025_12_16_05_33_58, + ]) + end + + def test_type + assert_equal([:second, @timezone], + [type.unit, type.timezone]) + end + end + + sub_test_case("YearMonthInterval") do + def build_array + Arrow::MonthIntervalArray.new([0, nil, 100]) + end + + def test_read + assert_equal([{"value" => [0, nil, 100]}], + read) + end + end + + sub_test_case("DayTimeInterval") do + def build_array + Arrow::DayTimeIntervalArray.new([ + {day: 1, millisecond: 100}, + nil, + {day: 3, millisecond: 300}, + ]) + end + + def test_read + assert_equal([ + { + "value" => [ + [1, 100], + nil, + [3, 300], + ], + }, + ], + read) + end + end + + sub_test_case("MonthDayNanoInterval") do + def build_array + Arrow::MonthDayNanoIntervalArray.new([ + { + month: 1, + day: 1, + nanosecond: 100, + }, + nil, + { + month: 3, + day: 3, + nanosecond: 300, + }, + ]) + end + + def test_read + assert_equal([ + { + "value" => [ + [1, 1, 100], + nil, + [3, 3, 300], + ], + }, + ], + read) + end + end + + sub_test_case("Duration(:second)") do + def build_array + Arrow::DurationArray.new(:second, [0, nil, 100]) + end + + def test_read + assert_equal([{"value" => [0, nil, 100]}], + read) + end + + def test_type + assert_equal(:second, type.unit) + end + end + + sub_test_case("Duration(:millisecond)") do + def build_array + Arrow::DurationArray.new(:milli, [0, nil, 100_000]) + end + + def test_read + assert_equal([{"value" => [0, nil, 100_000]}], + read) + end + + def test_type + assert_equal(:millisecond, type.unit) + end + end + + sub_test_case("Duration(:microsecond)") do + def build_array + Arrow::DurationArray.new(:micro, [0, nil, 100_000_000]) + end + + def test_read + assert_equal([{"value" => [0, nil, 100_000_000]}], + read) + end + + def test_type + assert_equal(:microsecond, type.unit) + end + end + + sub_test_case("Duration(:nanosecond)") do + def build_array + Arrow::DurationArray.new(:nano, [0, nil, 100_000_000_000]) + end + + def test_read + assert_equal([{"value" => [0, nil, 100_000_000_000]}], + read) + end + + def test_type + assert_equal(:nanosecond, type.unit) + end + end + + sub_test_case("Binary") do + def build_array + Arrow::BinaryArray.new(["Hello".b, nil, "World".b]) + end + + def test_read + assert_equal([{"value" => ["Hello".b, nil, "World".b]}], + read) + end + end + + sub_test_case("LargeBinary") do + def build_array + Arrow::LargeBinaryArray.new(["Hello".b, nil, "World".b]) + end + + def test_read + assert_equal([{"value" => ["Hello".b, nil, "World".b]}], + read) + end + end + + sub_test_case("UTF8") do + def build_array + Arrow::StringArray.new(["Hello", nil, "World"]) + end + + def test_read + assert_equal([{"value" => ["Hello", nil, "World"]}], + read) + end + end + + sub_test_case("LargeUTF8") do + def build_array + Arrow::LargeStringArray.new(["Hello", nil, "World"]) + end + + def test_read + assert_equal([{"value" => ["Hello", nil, "World"]}], + read) + end + end + + sub_test_case("FixedSizeBinary") do + def build_array + data_type = Arrow::FixedSizeBinaryDataType.new(4) + Arrow::FixedSizeBinaryArray.new(data_type, + ["0124".b, nil, "abcd".b]) + end + + def test_read + assert_equal([{"value" => ["0124".b, nil, "abcd".b]}], + read) + end + end + + sub_test_case("List") do + def build_array + data_type = Arrow::ListDataType.new(name: "count", type: :int8) + Arrow::ListArray.new(data_type, [[-128, 127], nil, [-1, 0, 1]]) + end + + def test_read + assert_equal([{"value" => [[-128, 127], nil, [-1, 0, 1]]}], + read) + end + end + + sub_test_case("LargeList") do + def build_array + data_type = Arrow::LargeListDataType.new(name: "count", + type: :int8) + Arrow::LargeListArray.new(data_type, + [[-128, 127], nil, [-1, 0, 1]]) + end + + def test_read + assert_equal([ + { + "value" => [ + [-128, 127], + nil, + [-1, 0, 1], + ], + }, + ], + read) + end + end + + sub_test_case("Struct") do + def build_array + data_type = Arrow::StructDataType.new(count: :int8, + visible: :boolean) + Arrow::StructArray.new(data_type, + [[-128, nil], nil, [nil, true]]) + end + + def test_read + assert_equal([ + { + "value" => [ + [-128, nil], + nil, + [nil, true], + ], + }, + ], + read) + end + end + + sub_test_case("DenseUnion") do + def build_array + fields = [ + Arrow::Field.new("number", :int8), + Arrow::Field.new("text", :string), + ] + type_ids = [11, 13] + data_type = Arrow::DenseUnionDataType.new(fields, type_ids) + types = Arrow::Int8Array.new([11, 13, 11, 13, 13]) + value_offsets = Arrow::Int32Array.new([0, 0, 1, 1, 2]) + children = [ + Arrow::Int8Array.new([1, nil]), + Arrow::StringArray.new(["a", "b", "c"]) + ] + Arrow::DenseUnionArray.new(data_type, + types, + value_offsets, + children) + end + + def test_read + assert_equal([{"value" => [1, "a", nil, "b", "c"]}], + read) + end + end + + sub_test_case("SparseUnion") do + def build_array + fields = [ + Arrow::Field.new("number", :int8), + Arrow::Field.new("text", :string), + ] + type_ids = [11, 13] + data_type = Arrow::SparseUnionDataType.new(fields, type_ids) + types = Arrow::Int8Array.new([11, 13, 11, 13, 11]) + children = [ + Arrow::Int8Array.new([1, nil, nil, nil, 5]), + Arrow::StringArray.new([nil, "b", nil, "d", nil]) + ] + Arrow::SparseUnionArray.new(data_type, types, children) + end + + def test_read + assert_equal([{"value" => [1, "b", nil, "d", 5]}], + read) + end + end + + sub_test_case("Map") do + def build_array + data_type = Arrow::MapDataType.new(:string, :int8) + Arrow::MapArray.new(data_type, + [ + {"a" => -128, "b" => 127}, + nil, + {"c" => nil}, + ]) + end + + def test_read + assert_equal([ + { + "value" => [ + {"a" => -128, "b" => 127}, + nil, + {"c" => nil}, + ], + }, + ], + read) + end + end + end + end + end +end + +class TestFileReader < Test::Unit::TestCase + include ReaderTests + + def setup + Dir.mktmpdir do |tmp_dir| + table = Arrow::Table.new(value: build_array) + @path = File.join(tmp_dir, "data.arrow") + table.save(@path) + File.open(@path, "rb") do |input| + @reader = ArrowFormat::FileReader.new(input) + yield + @reader = nil + end + GC.start + end + end + + def read + @reader.to_a.collect do |record_batch| + record_batch.to_h.tap do |hash| + hash.each do |key, value| + hash[key] = value.to_a + end + end + end + end + + def type + @type ||= @reader.first.schema.fields[0].type + end +end + +class TestStreamingReader < Test::Unit::TestCase + include ReaderTests + + def setup + Dir.mktmpdir do |tmp_dir| + table = Arrow::Table.new(value: build_array) + @path = File.join(tmp_dir, "data.arrows") + table.save(@path) + File.open(@path, "rb") do |input| + @reader = ArrowFormat::StreamingReader.new(input) + yield + @reader = nil + end + GC.start + end + end + + def read + @reader.to_a.collect do |record_batch| + record_batch.to_h.tap do |hash| + hash.each do |key, value| + hash[key] = value.to_a + end + end + end + end + + def type + @type ||= @reader.first.schema.fields[0].type + end +end