diff --git a/xml-stdlib/shared/src/main/scala/rapture/xml-stdlib/ast.scala b/xml-stdlib/shared/src/main/scala/rapture/xml-stdlib/ast.scala index 03cc9eb..ffea969 100644 --- a/xml-stdlib/shared/src/main/scala/rapture/xml-stdlib/ast.scala +++ b/xml-stdlib/shared/src/main/scala/rapture/xml-stdlib/ast.scala @@ -25,6 +25,7 @@ private[stdlib] object StdlibAst extends XmlBufferAst { override def dereferenceObject(obj: Any, element: String): Any = obj match { case n: Node if n.child.exists(_.label == element) => n \ element + case ns: NodeSeq if ns.exists(_.label == element) => ns.filter(_.label == element) case ns: NodeSeq if ns.exists(_.child.exists(_.label == element)) => ns \ element case _ => throw MissingValueException() } diff --git a/xml-test/shared/src/test/scala/rapture/xml-test/tests.scala b/xml-test/shared/src/test/scala/rapture/xml-test/tests.scala index 7a24abd..24508df 100644 --- a/xml-test/shared/src/test/scala/rapture/xml-test/tests.scala +++ b/xml-test/shared/src/test/scala/rapture/xml-test/tests.scala @@ -183,6 +183,11 @@ abstract class XmlTests(ast: XmlAst, parser: Parser[String, XmlAst]) extends Tes Xml(1648).toString } returns "xml\"\"\"1648\"\"\"" + val `Extract top level xml elements` = test { + val fooXML = Xml(Foo("Joe", 10)) + fooXML.alpha.as[String] + } returns "Joe" + /*val `Serialize array` = test { Json(List(1, 2, 3)).toString } returns "123" diff --git a/xml/shared/src/main/scala/rapture/xml/ast.scala b/xml/shared/src/main/scala/rapture/xml/ast.scala index b6a61a6..41dfa60 100644 --- a/xml/shared/src/main/scala/rapture/xml/ast.scala +++ b/xml/shared/src/main/scala/rapture/xml/ast.scala @@ -21,15 +21,15 @@ import scala.util._ trait XmlAst { - /** Dereferences the named element within the JSON object. */ + /** Dereferences the named element within the XML object. */ def dereferenceObject(obj: Any, element: String): Any = getObject(obj)(element) - /** Returns at `Iterator[String]` over the names of the elements in the JSON object. */ + /** Returns at `Iterator[String]` over the names of the elements in the XML object. */ def getKeys(obj: Any): Iterator[String] = getObject(obj).keys.iterator - /** Gets the indexed element from the parsed JSON array. */ + /** Gets the indexed element from the parsed XML array. */ def dereferenceArray(array: Any, element: Int): Any = getArray(array)(element) @@ -41,7 +41,7 @@ trait XmlAst { def isNull(any: Any): Boolean - /** Extracts a JSON object as a `Map[String, Any]` from the parsed JSON. */ + /** Extracts a XML object as a `Map[String, Any]` from the parsed XML. */ def getObject(obj: Any): Map[String, Any] def getChildren(obj: Any): Seq[Any] = { @@ -51,7 +51,7 @@ trait XmlAst { def fromObject(obj: Map[String, Any]): Any - /** Extracts a JSON array as a `Seq[Any]` from the parsed JSON. */ + /** Extracts a XML array as a `Seq[Any]` from the parsed XML. */ def getArray(array: Any): Seq[Any] def fromArray(array: Seq[Any]): Any