Skip to content

Commit

Permalink
[SPARK-6994] Allow to fetch field values by name in sql.Row
Browse files Browse the repository at this point in the history
It looked weird that up to now there was no way in Spark's Scala API to access fields of `DataFrame/sql.Row` by name, only by their index.

This tries to solve this issue.

Author: vidmantas zemleris <[email protected]>

Closes #5573 from vidma/features/row-with-named-fields and squashes the following commits:

6145ae3 [vidmantas zemleris] [SPARK-6994][SQL] Allow to fetch field values by name on Row
9564ebb [vidmantas zemleris] [SPARK-6994][SQL] Add fieldIndex to schema (StructType)
  • Loading branch information
vidmantas zemleris authored and marmbrus committed Apr 21, 2015
1 parent 04bf34e commit 2e8c6ca
Show file tree
Hide file tree
Showing 6 changed files with 137 additions and 0 deletions.
32 changes: 32 additions & 0 deletions sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,38 @@ trait Row extends Serializable {
*/
def getAs[T](i: Int): T = apply(i).asInstanceOf[T]

/**
* Returns the value of a given fieldName.
*
* @throws UnsupportedOperationException when schema is not defined.
* @throws IllegalArgumentException when fieldName do not exist.
* @throws ClassCastException when data type does not match.
*/
def getAs[T](fieldName: String): T = getAs[T](fieldIndex(fieldName))

/**
* Returns the index of a given field name.
*
* @throws UnsupportedOperationException when schema is not defined.
* @throws IllegalArgumentException when fieldName do not exist.
*/
def fieldIndex(name: String): Int = {
throw new UnsupportedOperationException("fieldIndex on a Row without schema is undefined.")
}

/**
* Returns a Map(name -> value) for the requested fieldNames
*
* @throws UnsupportedOperationException when schema is not defined.
* @throws IllegalArgumentException when fieldName do not exist.
* @throws ClassCastException when data type does not match.
*/
def getValuesMap[T](fieldNames: Seq[String]): Map[String, T] = {
fieldNames.map { name =>
name -> getAs[T](name)
}.toMap
}

override def toString(): String = s"[${this.mkString(",")}]"

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,8 @@ class GenericRowWithSchema(values: Array[Any], override val schema: StructType)

/** No-arg constructor for serialization. */
protected def this() = this(null, null)

override def fieldIndex(name: String): Int = schema.fieldIndex(name)
}

class GenericMutableRow(v: Array[Any]) extends GenericRow(v) with MutableRow {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1025,6 +1025,7 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru

private lazy val fieldNamesSet: Set[String] = fieldNames.toSet
private lazy val nameToField: Map[String, StructField] = fields.map(f => f.name -> f).toMap
private lazy val nameToIndex: Map[String, Int] = fieldNames.zipWithIndex.toMap

/**
* Extracts a [[StructField]] of the given name. If the [[StructType]] object does not
Expand All @@ -1049,6 +1050,14 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
StructType(fields.filter(f => names.contains(f.name)))
}

/**
* Returns index of a given field
*/
def fieldIndex(name: String): Int = {
nameToIndex.getOrElse(name,
throw new IllegalArgumentException(s"""Field "$name" does not exist."""))
}

protected[sql] def toAttributes: Seq[AttributeReference] =
map(f => AttributeReference(f.name, f.dataType, f.nullable, f.metadata)())

Expand Down
71 changes: 71 additions & 0 deletions sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql

import org.apache.spark.sql.catalyst.expressions.{GenericRow, GenericRowWithSchema}
import org.apache.spark.sql.types._
import org.scalatest.{Matchers, FunSpec}

class RowTest extends FunSpec with Matchers {

val schema = StructType(
StructField("col1", StringType) ::
StructField("col2", StringType) ::
StructField("col3", IntegerType) :: Nil)
val values = Array("value1", "value2", 1)

val sampleRow: Row = new GenericRowWithSchema(values, schema)
val noSchemaRow: Row = new GenericRow(values)

describe("Row (without schema)") {
it("throws an exception when accessing by fieldName") {
intercept[UnsupportedOperationException] {
noSchemaRow.fieldIndex("col1")
}
intercept[UnsupportedOperationException] {
noSchemaRow.getAs("col1")
}
}
}

describe("Row (with schema)") {
it("fieldIndex(name) returns field index") {
sampleRow.fieldIndex("col1") shouldBe 0
sampleRow.fieldIndex("col3") shouldBe 2
}

it("getAs[T] retrieves a value by fieldname") {
sampleRow.getAs[String]("col1") shouldBe "value1"
sampleRow.getAs[Int]("col3") shouldBe 1
}

it("Accessing non existent field throws an exception") {
intercept[IllegalArgumentException] {
sampleRow.getAs[String]("non_existent")
}
}

it("getValuesMap() retrieves values of multiple fields as a Map(field -> value)") {
val expected = Map(
"col1" -> "value1",
"col2" -> "value2"
)
sampleRow.getValuesMap(List("col1", "col2")) shouldBe expected
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,19 @@ class DataTypeSuite extends FunSuite {
}
}

test("extract field index from a StructType") {
val struct = StructType(
StructField("a", LongType) ::
StructField("b", FloatType) :: Nil)

assert(struct.fieldIndex("a") === 0)
assert(struct.fieldIndex("b") === 1)

intercept[IllegalArgumentException] {
struct.fieldIndex("non_existent")
}
}

def checkDataTypeJsonRepr(dataType: DataType): Unit = {
test(s"JSON - $dataType") {
assert(DataType.fromJson(dataType.json) === dataType)
Expand Down
10 changes: 10 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -62,4 +62,14 @@ class RowSuite extends FunSuite {
val de = instance.deserialize(ser).asInstanceOf[Row]
assert(de === row)
}

test("get values by field name on Row created via .toDF") {
val row = Seq((1, Seq(1))).toDF("a", "b").first()
assert(row.getAs[Int]("a") === 1)
assert(row.getAs[Seq[Int]]("b") === Seq(1))

intercept[IllegalArgumentException]{
row.getAs[Int]("c")
}
}
}

0 comments on commit 2e8c6ca

Please sign in to comment.