From 1462687d310d61ca4a3fd1787f9442e1b2b83cb1 Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Tue, 10 May 2022 04:03:56 +0000 Subject: [PATCH 1/3] add phi design doc --- docs/design/phi/design.md | 1749 ++++++++++++++++++++++ docs/design/phi/images/kernel-design.png | Bin 0 -> 102777 bytes docs/design/phi/images/tensor-design.png | Bin 0 -> 134604 bytes 3 files changed, 1749 insertions(+) create mode 100644 docs/design/phi/design.md create mode 100644 docs/design/phi/images/kernel-design.png create mode 100644 docs/design/phi/images/tensor-design.png diff --git a/docs/design/phi/design.md b/docs/design/phi/design.md new file mode 100644 index 00000000000..efb363f8ce2 --- /dev/null +++ b/docs/design/phi/design.md @@ -0,0 +1,1749 @@ + +飞桨高可复用算子库 PHI (Paddle HIgh reusability operator library),或者我们也成为函数式算子库,支持组合式算子功能复用、Primitive算子内核复用、插件式硬件加速库复用。针对飞桨框架原算子库存在的算子接口不清晰、算子复用成本较高、调用性能不够快的问题,我们重构了飞桨框架的算子库,设计了灵活、高效的函数式算子库 Phi,可以通过对函数式算子接口组合调用的方式实现新算子。新算子库提供了 200 余个跟 python 开发接口保持一致的 C++ 运算类 API,以及近500个可供组合调用的前、反向函数式算子内核 Kernel,可大幅降低框架原生算子和自定义算子的开发成本。新算子库支持Primitive API方式开发算子内核,可支持不同硬件(比如GPU和XPU)的算子内核复用。新算子库支持以插件方式接入硬件(比如NPU)的加速库,实现低成本复用硬件加速库。 + +> 本文档撰写于phi架构基本成型之时(2022年2月),仅代表该时间点的基本设计形态,可能和最新形态有细微差别;此外,在2.3版本发布的phi算子库仍然处于初期形态,后续仍然需要持续建设并完善,设计上也有可能调整。 + +# 一、背景与目标 + +> 介绍设计并建设 phi 要解决的问题 + + +最初 phi 项目的启动仅是为了优化飞桨动态图调度开销、并提升Kernel开发的复用能力而提出来的,但后续决定借此机会,建立能够同时在训练和推理场景(包括服务器端和移动端场景)中使用的“训推一体”算子库,长远上降低 paddle 生态中各基础设施开发及维护算子的成本,逐渐扩充了项目的目标范围,目前 phi 已经承载了多维度的意义。 + +> 关于算子库的命名,开发过程中有过迭代:初期算子库目录名为 pten ,意为paddle Tensor运算库 (Paddle Tensor Operation Library),因此一些历史 PR 以PTen为前缀,后期认为该名称表述范围不够准确,因此更名为 phi + +## 1.1 背景问题 + +具体地,phi 算子库项目,承载着解决 Paddle 以下问题的期望: + +### 1.1.1 Op&OpKernel之间可复用性差,冗余代码较多 + +2.3版本之前,Paddle中的Operator(后续简称Op)之间的可复用性比较差,仅在少数的反向Op中,通过在GradOpMaker实现中调用SetType复用了一些简单的运算,大部分本身可以复用已有Op实现的情况,代码都是copy重写的。 + +可复用性差的根本原因还是原先Op体系设计导致的: + +1. 当一个Op去复用另一个Op的`Opkernel::Compute`方法,都需要先构造一个`ExecutionContext`,复用上是比较繁琐的 + + - 如果能直接调用一个函数形式的Kernel,就会方便很多 + +2. 由于额外的数据结构构造及独立Op调度引入了开销,从计算性能的角度考虑,复用Op不如直接把计算代码copy过来,导致我们逐渐抛弃了早期反向Op复用前向Op的原则,开始为每个反向Op单独实现Kernel + + - 只有Op之前复用的开销足够小,复用已有Op实现新Op才有可能被大范围推广 + +### 1.1.2 执行调度的简洁性与细粒度化 + +#### 1.1.2.1 动态图 + +Paddle 2.0发布之后,多次收到内外部用户反馈动态图在小模型CPU执行场景下与竞品在性能上有数倍的差距。 + +这个问题的主要原因是:Padddle动态图C++端的执行路径比较冗长,调度开销比较重,这和动态图早期设计兼容静态图,继承了静态图Op的许多对象构造过程有关 + +- 问题issue:https://github.com/PaddlePaddle/Paddle/issues/28774 + +因此,动态图需要升级为基于函数的调度架构,抛开原先复杂的Op体系,才能解决这个问题,这依赖于OpKernel改为函数式的写法。 + +#### 1.1.2.2 静态图 + IR + +我们目前的静态图还不够“静态”,目前静态图仍然有许多运行时动态选择的逻辑,例如,运行时选择OpKernel,运行时判断是否要进行跨设备数据拷贝等等,但这些其实可以在静态图模型组网编译期间就确定下来,将执行过程确定为一系列OpKernel的执行,不再做动态的判断选择,从而进一步提升执行效率。 + +而这些依赖于OpKernel本身的细粒度化,将现有复杂的大OpKernel解耦成具体场景、具体设备的小Kernel,才能支持这样的调度。 + +### 1.1.3 自定义算子的易用性提升需求 + +2021年初上线的新自定义C++外部算子体系,在接口与函数编写的层面上,用法已经比较直观了,但是因为我们缺少基本运算的C++ API体系,事实上,在实现具体的自定义Op运算逻辑时,一些基础的加减乘除及矩阵运算都仍然需要重新实现一遍,不能复用Paddle已有的、经过优化的基础运算,因此一些复杂运算的外部开发成本仍然是比较高的。而要想复用Paddle内部的基础运算,有赖于的Op体系升级为函数式,并整理对应的C++ API体系才能解决。 + +### 1.1.4 共建训推一体算子库,降低推理算子维护成本 + +长久以来,由于paddle主框架和paddle-lite的算子是分开维护的,paddle新增的算子,lite需要的话,就要手动在lite中重新实现一遍,而且当主框架算子升级,lite又没有及时感知到,会直接导致推理模型在lite执行时出现bug,这维护成本是很高的,只有统一算子库,仅维护一份代码,才能长久解决这个问题。 + +因此,本次函数式算子库会由训练和推理共同建设,计算库整理完成后,作为独立的编译组件和底层基础设施(目前还没有独立拆分出来),能够同时服务于训练、预测以及Lite等执行体系。 + +### 1.1.5 推理新Runtime设计infrt的适配 + +推理设计了新的runtime infrt,预计要统一paddle-inference和paddle-lite的执行,将来需要直接调用本次共建的phi算子库中的算子,因此在设计时需要考虑对infrt的适配。 + +### 1.1.6 Op及Kernel参数规范化 + +python 2.0 API项目规范了Paddle Python端API的参数列表,使其变得简洁、易用,但是限于当时的情况,Op层面的参数列表并没有规范化,因此会有不少早期开发的算子和Python API参数相差较多,例如conv op这种,python API仅有7个参数,但C++ Op却有30+参数的分裂情况,而API和Op本质上是同一层概念,都是对一个运算的描述,参数应该是一致的。推理为了解决此问题,推动了算子定义增强项目,为部分不需要的参数添加了AsExtra以及AsQuant的声明,但并未从根本上解决问题,这也是phi算子库构建希望重点去解决的。 + +我们希望能做到,Python API -> Op(C++ API) -> Kernel API三层参数一致,使整体架构清晰,每一层复用也清晰,一套python官方文档,基本能够满足三层API的共同参考需求,不再着重维护额外的文档体系,降低维护成本。 + +## 1.2 目标及范围 + +- 总体目标:飞桨核心框架复用同一函数式算子库,基础数据结构Tensor具备良好的可扩展性,从根本上做到训练推理协同一致、基础组件稳定可靠、增量开发体验良好。 + +- 目标范围: + + - phi算子库初期构建更关注Kernel“迁移”,人力因素,原Kernel逻辑迁移时暂不强制升级为“组合式”写法,前反向Kernel均如此 + - phi算子库初期提供的"组合式Kernel二次开发"能力面向后续增量的新算子使用,已有算子仍然保持其原先的编码实现,降低迁移成本 + - phi算子库初期提供的“新硬件扩展能力”仅在新硬件自身范围内提供,比如XPU已经实现了50个Kernel,后续其可以基于50个Kernel去组合新的Kernel,但这仅限于XPU范围内,其实现不和CPU、CUDA等实现通用 + - phi算子库项目重点关注“Kernel函数化&Op规范化”的工作,Kernel改为函数式,C++API与Op命名及参数列表在尽可能确保兼容性的前提下与逐渐规范化为与Python API一致 + + +# 二、设计概览 + +## 2.1 命名及位置 + +飞桨高可复用算子库 (Paddle HIgh reusability operator library),简称 PHI(phi),phi代码目录在paddle目录下,和fluid平级,而不是放在fluid目录下,这样放置的原因是:phi是一个由fluid,lite,infrt等多种上层runtime共同调用的基础组件,后续会作为单独的编译的动态库存在,不适合作为fluid的子模块。 + +## 2.2 目录结构 + +### 2.2.1 目录结构设计需满足的需求 + +训练和推理对算子库目录的清晰度也有诸多诉求: + +- 在目录设计上支持算子库的各种拆分编译需求,包括 + + - 按运算设备拆分编译 + - 例如:仅编译cpu的,或者仅编译cuda的 + - 按训练和推理场景拆分编译 + - 例如:推理不编译反向相关kernel,也不编译带有Intermediate输出的前向kernel + - 按移动端设备实际使用算子精准裁剪编译(目前尚未支持) + - 例如:一个模型只用了add和mul,极致情况下应该能裁到仅剩2个kernel +- 长线上支持良好的kernel复用实现需求 + - 解释:kernel复用实现时,能否通过简单的include引入对应函数,不会因为目录过于复杂而找不到复用的kernel + +- 长线上支持跨设备kernel的写法统一需求,并且直观易用,不引入不必要的模板参数 + - 解释:算子库下层还有Kernel Primitive API模块,其长线愿景是每个运算,只要一个kernel,能够适应多种设备,真正区分设备的代码,仅在Kernel Primitive API实现中;不希望未来的kernel在复用时从传入较复杂的模板参数,需要尽可能限制地简洁一些 + +- 易用性上,开发者能精准理解自己新增Kernel应该放到什么位置,无歧义 + - 解释:开发者新增一个API,不会困惑自己应该将对应kernel放在那个目录,也不会出现不同的人对于同一个kernel应该放在什么位置出现二义性的理解 + +- 不引入大量的重复目录设计 + - 解释:概念拆分是需要的,但也要有边界,避免在多个目录下有命名相同的子目录,容易混乱,比如不能cpu下面有eigen, funcs, math等,gpu下面也有。新算子库的目录设计以根据设备拆分为主,其他层次的目录拆分尽可能弱化,比如尽量不根据功能拆分,尽量不根据领域拆分等 + +- 不造成迁移时的文件数目膨胀 + - 解释:不能因为kernel设备拆分,导致kernel实现文件大规模增多 + +- 不引入层级过深的目录设计 + - 解释:目录层级不应过深,理解和维护成本都较高 + +- 不引入过高的迁移成本 + - 解释:迁移kernel时,不能要求对kernel本身做太多改动和拆分,否则迁移成本太高 + +### 2.2.2 具体目录设计 + +#### 2.2.2.1 一级目录 + +``` +paddle/phi +./api (对外暴露的高层API及其实现) + ./include(对外暴露的高层API头文件) + ./lib(对外暴露API的实现) +./common (内外部均会使用到的基础数据结构) +./core (基础组件,比如基础Tensor相关接口,kernel注册接口,管理单元等) +./backends (各设备及后端的基础组件,下设cpu,gpu等后端目录) +./infermeta (shape、dtype、layout等meta信息的推导函数) +./kernels (各设备及后端的kernel实现) +./ops (各op的定义,后续采取自动生成的方式完成大部分工作,目前仅有兼容用的代码) +./tests (单元测试) +``` + +部分目录结构说明: + +- `api`:API模块,面向外部用户 + - 直接使用类python的C++ Tensor计算 API,和Python端形式高度一致 + - 该部分可能反向依赖框架的DeviceContextPool等实现,所以单独管理 + - 在该类API上,训练和预测也可能是不同的 +- `common`:phi内部及phi api目录均要使用的数据结构,这些数据结构既不属于phi core,也不属于api目录 +- `core`:phi内部会有一些自己需要的,公用的模块实现,比如基础DenseTensor、,kernel注册及管理模块 +- `backends`:backends中组织后续需要为各个后端的新增的数据结构,比如CPUContext、GPUContext等 + - core中放置对于算子库来讲通用的基础数据结构,而特定后端的专用数据结构不放在core中,且依赖关系严格保证backends依赖core,但core不能依赖backends + - 例1:Context如果有基类,则在core中,而继承的CPUContext在backends/cpu中,GPUContext在baackends/gpu中 + - 例2:TensorBase在core中,DenseTensor给多数设备使用,也在core中,如果有MKLDNNTensor的话,因为它只给mkldnn用,应该在backends/dnnl中 +- `infermeta`: infermeta函数的整理位置,infermeta函数相当于infershape+inferdtype+inferlayout等 +- `kernels`:各设备相关kernels + - `cpu, gpu, ...` +- `ops`: ops中组织新形式的Op定义、以及兼容原有op的一些组件 + + +#### 2.2.2.2 Kernels目录 + +``` +paddle/phi/kernels +./ (放置设备无关的kernel声明和实现) +./cpu(仅放置cpu后端的kernel实现) +./gpu +./xpu +./dnnl +./gpudnn +./impl (考虑到现状,放置原先Kernel在CPU和GPU或其他设备一致的实现,便于复用) +./funcs(放置原fluid operators下一些支持多设备的functor和funcs) +./primitive(放置Kernel Primitive API的基础实现) +... +``` + +目录结构说明如下: + +- kernels下主目录,放置设备无关的kernel.h和kernel.cc,原则上每个kernel一个.h和.cc + - 例如一个kernel是使用Primitive api实现的,或者是复用其他基础kernel实现的,那么不论在什么设备上,应该都只有一种实现,所以它的声明和实现均直接放置到kernels目录下即可(这是将来的理想状态) + - 目前我们大部分kernel都不具备跨设备实现统一的特征,但是kernel的输入参数返回值除了DeviceContext之外,应该是一致的,所以kernel参数声明头文件还放到主目录下(和原先的设计保持一致,DeviceContext和T作为模板参数),各设备的函数实现在相应的设备文件夹中 + - 注意,这里跨设备实现统一,并不是指一个kernel的CPU和GPU实现就算统一了,而是在所有设备的实现都一样,目前至少包括CPU,GPU,XPU,MKLDNN,GPUDNN等 + - 反向kernel如果不需要支持裁剪,可以做适当归并(但如果要为支持端侧训练留可能性,反向kernel可能也是裁剪的潜在目标) +- kernels下一级子目录,原则上按照backend分类按需新建,仅保留两个特殊的目录: + - funcs:为了兼容原先fluid operators中functor和function设计保留的目录,放置支持多种后端的function和functor,还按照原先的一个头文件,多个.cc(u)的方式组织(这部分代码在将来可能被移除,因为会逐渐被Kernel Primirive API及Kernel间复用替代,这里不做过度设计) + - 例1:一个公共函数XXXFunction在reduce CPU和reduce CUDA的kernel实现中都被调用,并且reduce CPU和reduce GPU的kernel实现是不一样的,那么这个XXXFunction应该在funcs目录中 + - primitive:Kernel Primitive API,多设备统一kernel实现的一些基础工具 + - impl:paddle目前的op kernel实现,有很多仍然是CPU和GPU复用同一份代码的,在大量的xx_op.h,这部分代码,不适合放在cpu或者gpu目录中,也不适合放在funcs目录中(会导致funcs目录中最终放置了相当一部分kernel实现,过于臃肿且混乱,funcs目录的定位是放置原先operators/math目录下那样的工具functor和function),也不适合放到kernels根目录下(并不是真正设备无关的实现,仅是cpu和gpu共用的实现),因此为了使这部分代码迁移时不需要做过多考虑,并且放置的位置也相对符合其实现性质,创建了impl这个目录 + - impl目录下,仅放置跨部分设备实现一致的kernel函数,均为头文件,命名均以xxx_kernel_impl.h为后缀 + - 例如:scale,fill_constant,fill_any_like这些kernel均属于此类情况 +- kernel迁移过来之后,首先创建对应kenrel头文件直接放置到kernels的根目录中,各后端的kernel实现放在相应的设备文件夹中 + - 可参考原先op的归并程度,如matmul原先是单独的.h/.cc,那移过来之后保持,但activation相关的基本写在一个.h/.cc,移过来也仍然保持归并(后续有必要再进一步拆分) + - 例1:原先cast op的Kernel在cast_op.h中,迁移过来之后在根目录创建cast_kernel.h,cast_kernel.cc/cu根据使用的后端放到对应的目录,即cast_kernel.cc放置到cpu中,cast_kernel.cu放置到gpu中 + - 例2:原先scale op的kernel使用eigen实现,CPU和GPU实现一致,迁移过来之后,公共实现应该在impl中的scale_kernel_impl.h中,公共头文件在kernels根目录下的scale_kernel.h中,scale_kernel.cc在cpu中,scale_kernel.cu在gpu中 +- 迁移时,只有本kernel用到的辅助函数,一律和kernel实现放到同一个backend文件中,创建.h管理代码,不再单独在别处整理代码,除非这些辅助的函数实现是有多处使用的 + - 即使有多处调用,如果仍然限于同一设备,直接建头文件放到同一个目录下 +- 反向kernel与前向kernel实现放置在不同的文件中,文件后缀采用``*_grad_kernel.*``,便于cmake分离编译 + - 不再为反向kernel单独创建目录,否则反向kernel目录下还要创建cpu/gpu等目录 + - 二阶导、三阶导的实现统一也放到grad kernel实现文件中 + +- 为什么目录名叫`gpu`而不是`cuda`和`hip`? + - cuda和hip代码重复度非常高,统一实现维护成本较低 + + +## 2.3 核心组件 + +### 2.3.1 公共基础数据结构 + +#### 2.3.1.1 Backend + +``` +/** + * [ Why need Backend? ] + * + * Backend not only means place. Backend is a superset of place. + * + * Place cannot indicate the difference in calculation methods on the device, + * but in order to make the boundary of the kernel clearer and the function + * more specific, we need to distinguish the calculation method. + * + * Such as the kernel for CPU device, it can be a native CPU kernel, + * or a kernel implemented by MKLDNN library. + * + * Note(chenweihang): HIP is not needed now, we can added it if needed + * in the future + */ +enum class Backend : uint8_t { + UNDEFINED = 0, + + // basic kernel backend + CPU, + + // various acceleration devices' backends + GPU, + XPU, // XPU currently does not exist at the same time as CUDA + NPU, // NPU currently does not exist at the same time as CUDA + + // the third library backend + MKLDNN, + GPUDNN, + + // end of backend types + NUM_BACKENDS, + + /** + * [ Why we need ALL in baisc kernel key member? ] + * + * For Tensor, ALL represents an illegal Backend, but for Kernel, some + * kernels may be device-independent by nature, such as reshape; and when + * and some kernels are also device-independent when implemented based on + * primitive API. + * + * In this case, we need to provide a more concise registration method, + * instead of registering the kernels for each device with almost + * repetitive code, we need one registration covers all situations, + * so if we provide the ALL field with Register the kernel in this statement. + * + * Of course, we have also considered solving this problem through different + * named macros, for example, if we define + * + * PT_REGISTER_KERNEL_FOR_ALL_BACKEND + * + * Based on this design pattern, the dtype and layout also have the same + * requirements, this cause we need to define a series of macros + * + * PT_REGISTER_KERNEL_FOR_ALL_DTYPE + * PT_REGISTER_KERNEL_FOR_ALL_LAYOUT + * PT_REGISTER_KERNEL_FOR_ALL_BACKEND_AND_LAYOUT + * PT_REGISTER_KERNEL_FOR_ALL_BACKEND_AND_DTYPE + * PT_REGISTER_KERNEL_FOR_ALL_LAYOUT_AND_DTYPE + * PT_REGISTER_KERNEL_FOR_ALL_BACKEND_AND_LAYOUT_AND_DTYPE + * + * It makes the system of registering macros more complicated, we think + * this is not a simple design, so we still adopt the design of providing + * the ALL field. + * + * Note: ALL_BACKEND only used for Kernel registration and selection + */ + ALL_BACKEND = UNDEFINED, +}; +``` + +#### 2.3.1.2 DataLayout + +``` +// Note: Here the DataLayout is public api for external users, the prefix `k` +// maybe confuse users, so we use all uppercase names +enum class DataLayout { + UNDEFINED = 0, + // TODO(chenweihang): keep ANY for compatibility, remove it later + ANY = UNDEFINED, + NHWC, + NCHW, + MKLDNN, + NUM_DATA_LAYOUTS, + // See Note [ Why we need ALL in basic kernel key member? ] + ALL_LAYOUT = UNDEFINED, + // Note: Unify phi DataLayout and fluid::framework::DataLayout, + // for compatible with fluid DataLayout, here need prefix `k` + // Note: The original `kAnyLayout (enum value 2)` is a strange design. + // `kAnyLayout` originally cannot represent any kind of Layout, + // at the same time, it can also represent any Layout. + // Strictly, it means "default" or "undefined" layout, + // and should not be mixed with other meaningful layouts. + kAnyLayout = ANY, + kNHWC = NHWC, + kNCHW = NCHW, + kMKLDNN = MKLDNN, // all layouts supported by MKLDNN internally +}; +``` + +#### 2.3.1.3 DataType + +``` +enum class DataType { + UNDEFINED = 0, + BOOL, + INT8, // Char + UINT8, // BYte + INT16, + INT32, + UINT32, + INT64, + UINT64, + BFLOAT16, + FLOAT16, + UINT16, + FLOAT32, + FLOAT64, + COMPLEX64, + COMPLEX128, + NUM_DATA_TYPES, + // See Note [ Why we need ALL in baisc kernel key member? ] + ALL_DTYPE = UNDEFINED, +}; +``` + +- 这里什么不使用原先fluid的VarType? + - 理由1:原先fluid的DataType和VarType是同级概念,设计是比较混乱的,例如LoDTensor和FLOAT32是同级概念,但这两者显然不是的,我们不希望继承原先有明显缺陷的设计 + - 理由2:和fluid解耦依赖,便于后续phi可以独立编译 + +#### 2.3.1.4 Scalar + +Scalar (标量)用来统一表示具有不同基础数据类型(float, double, int, bool等)的变量。(目前也支持表示元素数量为1的Tensor标量,但后续可能会放弃该功能的支持) + +以`ScaleKernel`为例,其中的`scale`参数可以传入int,float,double等普通数据类型。如果不使用`Scalar`来表示的话,需要为每种数据类型单独创建一个函数接口,这样会大大增加开发Kernel的代码量,因此`Scalar`主要应用在具有不同数据类型的同一参数上,可以避免该场景下需要编写多个重载函数的问题。 + +``` +template +void ScaleKernel(const Context& dev_ctx, + const DenseTensor& x, + const Scalar& scale, + float bias, + bool bias_after_scale, + DenseTensor* out); +``` + +#### 2.3.1.5 IntArray + +IntArray 是一个整数类型数组,可以由`vector`,`Tensor`以及`vector`进行构造,目前主要用来表示shape,index以及aixs等维度索引变量。 + +以FullKernel为例,其中的shape参数用来表示返回Tensor的维度信息(如[2,8,8]),在调用FullKernel时该项参数传入`vector`,`Tensor`和`vector`类型的变量兼可完成调用。使用IntArray避免了每种shape类型单独编写一个重载函数的问题。 + +``` +template +void FullKernel(const Context& dev_ctx, + const IntArray& shape, + const Scalar& val, + DenseTensor* out); +``` + +### 2.3.2 Tensor体系 + +整体设计类图如下 + +![tensor-design.png](./images/tensor-design.png) + + +以下依次进行介绍。 + +#### 2.3.2.1 API Tensor接口 + +- 最上层是API级别的Tensor接口封装,里面包含两个指针成员,TensorBase和AbstractAutogradMeta。 + - 两个成员均使用了Interface设计,不会依赖于真实的Tensor和Autograd实现 + - AutogradMeta仅在动态图API级别的Tensor中有意义,在具体的kernel计算中,不会被使用到,所以将其放到最上层的Tensor接口中 + - 另外,这样设计也是为了方便数据共享,并且减少拷贝开销 + - 当一个Tensor赋值给另一个Tensor,或者Tensor作为函数返回值时,实际上只会拷贝指针,不会产生真实的数据拷贝 + +- 最上层C++ Tensor与Python端Tensor扮演类似的角色,在接口设计上尽可能与Python端保持一致 + - 包含基础的Tensor属性访问及数据访问方法 + - shape, place, dtype, data + - 包含动态图Tensor需要的autograd方法 + - gradient, backward + - 包含Tensor间的转换方法 + - cpu, gpu, xpu等 + - 包含tensor相关的计算方法(暂未添加) + - `paddle.tensor` 模块下所有方法 + +- 编译解耦: + + - 这里带有的autograd信息,只是一个指针索引,默认为空 + - `std::unique_ptr autograd_meta_ = nullptr;` + - 而这里的AbstractAutogradMeta是一个抽象类接口,不会依赖autograd的任何模块,因此不会影响 phi 的独立编译,同时又兼顾了动态图Tensor需要持有反向信息的需求 + +- 这里的AutogradMeta仅在动态图场景中才会设置,不需要的场景,比如静态图内就仅仅是个空指针而已 + +Tensor设备的判断及转换 + +- Tensor的设备及类型判断 + +``` +bool is_cpu() const; +bool is_gpu() const; +bool is_xpu() const; +bool is_dense_tensor() const; +bool is_selected_rows() const; +bool is_opencl() const; // 待添加 +bool is_metal() const; // 待添加 +``` + +- Tensor间类型转换,通过与Python端一致的API实现(待添加) + +``` +Tensor cpu() const; // 转换为cpu tensor +Tensor gpu() const; // 转换为gpu tensor +Tensor xpu() const; +Tensor mkldnn() const; +``` + +- 这个转换的过程可能是cast,也可能是copy + - 如果不需要进行数据拷贝,就是cast + - 如果需要进行数据拷贝,就是copy + - 转换通过函数式kernel去实现 + +- 在API场景中的使用 + - 用户在完整训练场景中,使用API的时候,最初读入的数据一般是从磁盘读入,先放入CPU,然后再转换到具体执行设备上,比如DataLoader + +#### 2.3.2.2 TensorBase + +- Tensor实现的接口类,接口中仅包含必要的纯虚Tensor方法,不包含有实际含义的成员,这里的方法在开发过程中也要严格控制 + +- 为什么要在这一层用抽象类设计? + - 一方面是为了隔离Tensor API与Tensor具体实现,不产生过多依赖,如果将来Tensor API需要重新设计,或者说需要放弃掉autograd信息,只需要重新设计一个Tensor API即可,对于底层Tensor的实现几乎没有影响 + - 另一方面是为了给异构化的Tensor保留充足的扩展空间,框架API层仅需要一个Tensor数据结构即可,不需要再暴露多种数据结构设计,这里其实做了一个大范围定义,框架内所有数据结构均是Tensor + - 对于内存布局基本一致,或者说Tensor描述基本一致的实现,可以基于一种DenseTensor的实现去继承 + - 如果是异构化程度高的Tensor,可以直接从Interface继承去实现新的Tensor分支,比如只有一个Object的Tensor,确保在Tensor扩展灵活性上不会出现瓶颈 + +#### 2.3.3.3 DenseTensor、SparseTensor + +- 对应原fluid内的LoDTensor类,是Tensor的基类实现,Allocation就是现有Allocation,包含现有Tensor的基础成员 +- SparseCsrTensor、SparseCooTensor为新设计的稀疏tensor类型,详见代码实现 + +> 为了兼容原先框架调度及算子,SelectedRows我们也迁移过来作为一种基础Tensor类型,后续如果能够被新的稀疏Tensor替代,长期会移除 + +#### 2.3.3.4 其他异构Tensor + +- 如果现有Allocation的描述无法满足一些第三方库对于Tensor内存的描述需求,可以继承TensorBase之后,使用新的Allocation实现 +- 而这种Tensor本质上没有脱离通用Tensor的范畴,只是访存方式有所区别,其他的TensorMeta信息,它仍然是需要的 +- 可以自行定义特殊的TensorAllocation描述类,去构建自定义的Tensor,例如MetalTensor + +``` +template +class SpatialTensor : public TensorBase { + public: + SpatialTensor(std::shared_ptr allocation, + std::unique_ptr meta) + : allocation_(std::move(allocation)), + meta_(std::move(meta)) {} + + private: + std::shared_ptr allocation_; + std::unique_ptr meta_; +}; + +template +class MetalTensor : public SpatialTensor {}; + +template +class OpenCLTensor : public SpatialTensor {}; +``` + +- 通过这种方式,无论Tensor的需求如何特殊,均可以在对外API保持一致的前提下进行内部适配 + +其他高自由度Tensor继承:直接继承TensorBase + +- TensorBase是抽象类,为具体Tensor的描述留了较大的空间,如果传统Tensor的描述无法满足需求,可以设计特异化的Tensor实现 + + +### 2.3.3 C++ API + +#### 2.3.3.1 C++ API形式 + +> 本节要点: +> 1. C++ API与Python 2.0 API对应,函数名、参数名、参数顺序、返回值均一致 + +经过调研,我们发现只有框架产品在设计时考虑了C++ API易用性层面的问题的。出于长期考虑,我们若想要吸引更多的开发者共建飞桨生态,提供规范易用的C++ API体系也是十分重要的。同时,Python 2.0 API项目为C++ API奠定了良好的参考基础,我们可以直接继承其成果。 + +因此,目前我们期望Tensor计算库的C++ API声明形式如下: + +``` +Tensor mean(const Tensor& x); + +Tensor scale(const Tensor& x, + const Scalar& scale, + float bias, + bool bias_after_scale); +``` + +说明如下: + +- 尽可能与Python API属性保持一致,函数名,参数列表,返回值均保持一致,使用户在Python与C++的切换中,几乎没有新增的学习成本(如果必须不一致,可以增加新的C++ API,Python已有的运算类API与C++ API一一对应) + +**这个新建的C++ API体系目前主要用于什么场景?** + +1. 作为自定义算子开发时可调用的C++ API,提升易用性 + - 例如现在用户在自定义算子中初始化一个Tensor需要循环遍历Tensor数据并赋值,有API之后可以直接调用`paddle::ones`,`paddle::fill`这些API +2. 作为新动态图的基础调用单元 + - 新动态图会以API作为调度计算单元,不会再调用Op体系,以提升调度性能 +3. 作为反向Op复用前向Op进行开发的基础 + - 现在反向op kernel需要单独实现,在API体系成型后,希望可以通过复用前向API完成反向Op实现 + +#### 2.3.3.2 C++ API自动生成 + +**为什么要自动生成C++ API?** + + - C++ API的实现代码在形式上相对固定,理论上可以采用自动生成的方式来实现 + - 使用代码自动生成可以有效降低C++ API的开发成本,且方便修改和维护 + +**如何自动生成C++ API?** + + C++ API的自动生成是通过解析Yaml配置文件来进行生成的,Yaml配置文件分为: + + - 前向API配置文件(`python/paddle/utils/code_gen/api.yaml`,解析后生成代码文件为`paddle/phi/api/include/api.h`和`paddle/phi/api/lib/api.cc`) + - 反向API配置文件(`python/paddle/utils/code_gen/backward.yaml`,解析后生成的代码文件为`paddle/phi/api/backward/backward_api.h`和`paddle/phi/api/lib/backward_api.cc`)。 + +C++ API生成的关键在于Yaml文件的配置,以matmul为例,其前向和反向的配置文件如下: + +``` +# 前向API配置 +- api : matmul + args : (Tensor x, Tensor y, bool transpose_x=false, bool transpose_y=false) + output : Tensor + infer_meta : + func : MatmulInferMeta + kernel : + func : matmul + backward : matmul_grad + +# 反向API配置 +- backward_api : matmul_grad + forward : matmul (Tensor x, Tensor y, bool transpose_x, bool transpose_y) -> Tensor(out) + args : (Tensor x, Tensor y, Tensor out_grad, bool transpose_x=false, bool transpose_y=false) + output : Tensor(x_grad), Tensor(y_grad) + infer_meta : + func : MatmulGradInferMeta + kernel : + func : matmul_grad +``` + +其中各项配置参数含义: + +- api:函数名称,需与Phi Kernel注册的函数名相同 +- args:函数参数,顺序和数据类型必须与Phi Kernel同名函数完全一致,Attributes类型必须排在Tensor类型之后。 +- output:输出类型,如果有多个输出间用逗号(“,”) 分隔开。可以在类型后用"()"选择性标记每个输入的名字(如`Tensor(out)`),如果没有标记则默认处理为out0, out1, … +- infer_meta:计算返回Tensor的维度与类型(详见InferMeta函数介绍) + - func为调用的InferMeta函数,默认输入为args项的所有参数和output参数,其中的Tensor类型变量会自动替换为MetaTensor。 +- kernel:API调用的具体Kernel函数 + - func:kernel函数的注册名(REGISTER使用的name,非函数名),默认输入为args项的所有参数和output参数 +- backward:(可选)对应的反向函数名称,没有则生成纯前向API。 + +Yaml解析脚本将根据上述配置项自动生成对应的C++ API,生成的代码中会完成包括Kernel自动选择、Tensor转换、Data Transform、InferMeta以及Kernel调用等相关处理逻辑,具体可参考生成的`api.cc`内代码。 + +由于C++ API数量较多,且有着各种各样的形式与功能,为此在Yaml配置机制上也提供了很多更为灵活的配置项,如`invoke`等。 + +### 2.3.4 Kernel形式、注册及管理 + +#### 2.3.4.1 Kernel形式 + +> 本节要点: +> 1. Kernel函数形式要点: +> (1)数据类型T,与DeviceContext(简写为Context)作为模板参数; +> (2)Context作为Kernel第一个参数; +> (3)返回值Tensor以指针形式作为输入参数,Kernel本身返回值为void + +这一层是具体的Kernel层,这一层实现的函数,会作为Kernel注册到框架中,供框架统一查找和调度。 + +目前我们期望这一层的形式如下,以`scale`为例: + +``` +template +void Scale(const Context& dev_ctx, + const DenseTensor& x, + float scale, + float bias, + bool bias_after_scale, + DenseTensor* out) { + ... +} +``` + +说明如下: + +- 不同设备的kernel要有不同的函数实现,函数名采用**驼峰式命名**,除了首字母大写之外,命名尽可能和API函数名保持一致,同一个计算的函数命名保持一致,通过不同文件或者目录管理不同设备的函数 +- 一般有两个模板参数,T和Context(尽可能),用于运行时决定数据类型和设备类型 + - 按照我们目前的体系,绝大多数的Kernel都是按照**特化DeviceContext和数据类型**这种方式缩减代码的,这与原先OpKernel的形式一致性比较强 + - 形式要统一,将来如果Kernel层也作为细粒度API暴露的话,易用性有保障 +- 函数输入参数规定: + - 以具体的DeviceContext作为第一个输入参数,如CPUContext,CUDAContext,用于满足运行时需要特定上下文信息的需求,如多stream需要传stream进来 + - 暂不支持一个Kernel传入多个DeviceContext参数,目前认为这样的需求不太合理 + - 参数列表和API保持一致,如果有其他的特殊信息需要传入Kernel,通过Context传递 + - 随后是所有的输入Tensor与输入Attribute,均以const &方式传入,POD类型直接以值传入 + - 输入的Tensor是具体的Tensor类型,如DenseTensor或SelectedRows,不是对外接口API那个Tensor + - 最后是函数的返回值Tensor,以指针形式传入 + - 为了满足灵活性,让kernel可以适配更多的场景,后续会允许声明灵活类型的输入、输出和参数,参考tfrt的Argument(输入), Attribute,(属性) Return(输出)等模板,以适配非Tensor的输入输出,以及Tensor类的Attribute,让机制更加灵活 +- 函数内部实现按需决定: + - 短期: + - 将现有OpKernel内实现,迁移到具体的设备Kernel内 + - 将存在设备公用的OpKernel实现抽离为函数,由多个设备Kernel共同调用 + - 长期: + - 复杂Kernel直接调用基础Kernel完成计算,鼓励Kernel复用,简化代码 + +> FAQ: + +>- 为什么需要使用模板参数?为什么不和torch一样,没有模板参数? + - 运行时数据类型T和设备Device的选择是在kernel选择时必要的操作,各个框架都是一样的 + - torch在写法上避免了用模板实现kernel选择,但实际上采用了全局kernel map查找的选择方式,这种方式的开销是比较重的,一个kernel的执行过程中,可能存在多次kernel map的查找 + - 基本流程如下图: + - ![图片](http://bos.bj.bce-internal.sdns.baidu.com/agroup-bos-bj/bj-2aafdb051eaea7120bdf9604eb738029dcd3162a) + - 这种方式存在的性能问题已经被torch自身认识到,所以torch也在做算子库重构,但是积重难返,他们重构也并未对此问题从根本上解决,只是减少了一些redispatch的层数,我们不能一味模仿竞品自身都认为有问题的设计 +- 为什么第一个参数需要是DeviceContext?为什么不能不传? + - phi kernel要求是纯函数形式,即函数内使用的变量均通过参数传入,或者在函数内部创建,不允许在函数内部使用全局单例,为了适配多样的kernel需求,像DeviceContext这种存储上下文信息的参数是必要的 +- 为什么需要两个模板参数? + - 为了方便设备无关kernel的复用,假如我们要实现一个傅里叶变换fft kernel,假设这个kernel能够使用基础kernel组合得出, + +#### 2.3.4.3 Kernel实现 + +> 本节要点: +> 1. Kernel专注表达数学算法,不掺杂调度逻辑 +> 2. Kernel足够细粒度,边界清晰,没有可选参数,便于复用 + +现有Kernel因为Op参数过于复杂,引入了调度逻辑,例如 + +- 通过`use_cudnn`判断是否执行cudnn分支,在新的Tensor计算库中,使用cudnn计算是单独的Kernel + +为了降低成本,Phi Kernel实现会尽可能继承原先的OpKernel实现,大部分Kernel的实现仅需要将原先OpKernel中取Input,Output的逻辑移除,并且修改一些关键方法即可,以sign为例: + +原先sign OpKernel: + +``` +template +class SignKernel : public framework::OpKernel { + public: + virtual void Compute(const framework::ExecutionContext& context) const { + auto* out = context.Output("Out"); + auto* in = context.Input("X"); + out->mutable_data(in->place()); + + auto eigen_out = framework::EigenVector::Flatten(*out); + auto eigen_in = framework::EigenVector::Flatten(*in); + auto& place = + *context.template device_context().eigen_device(); + EigenSign, T>::Eval(place, eigen_out, + eigen_in); + } +}; +``` + +迁移后的phi sign kernel: + +``` +template +void SignKernel(const Context& dev_ctx, + const DenseTensor& x, + DenseTensor* out) { + dev_ctx.template Alloc(out); + auto eigen_out = phi::EigenVector::Flatten(*out); + auto eigen_x = phi::EigenVector::Flatten(x); + + auto& dev = *dev_ctx.eigen_device(); + paddle::operators::EigenSign, T>::Eval( + dev, eigen_out, eigen_x); +} +``` + +除了kernel形式从结构体变为函数式之外,还有两处主要变化: + +1. 由于参数都是具体的输入,所以不需要再到context里取输入输出,相关代码移除 +2. phi kernel中要求输出tensor的内存申请统一使用`ctx.Alloc`或者`ctx.HostAlloc`方法,不能再使用原先的`mutable_data`申请内存 + +> FAQ +> 1. 为什么mutable_data要替换成ctx.Alloc? +> 答:因为原先的mutable_data方法中调用的全局方法memory::AllocShared内部使用了全局单例进行内存分配,这不符合前面说过的纯函数设计原则,从业务需求上来讲,kernel里面如果使用单例确定显存分配的方式,在推理的多线程环境中,不能线程不能指定不同的存储分配方式。 + + +#### 2.3.4.4 Kernel注册 + +> 本节要点: +> 1. Kernel将自身全部关键信息暴露给框架,记录其输入、输出和属性的信息,否则将导致框架调度与 Kernel 计算之间界限不清 + +现有 fluid Kernel 注册时仅记录了 Kernel 的 place,layout,dtype,输入输出等统一由 ExecutionContext管理,没有相应的信息记录,现在kernel要改成函数式,每一个函数的输入输出和属性都是明确的,我们希望在这里记录每一个输入输出的信息,也是为了兼容paddle-lite的调度。 + +同时,我们需要简化Kernel注册的写法,现有的写法都不够简洁: + +1. fluid的Kernel注册写法,有不少冗余信息,以scale为例,可以看到每个kernel除了最后的data type,前面函数名和DeviceContext特化的信息都是冗余的 + + ``` + REGISTER_OP_CPU_KERNEL( + scale, ops::ScaleKernel, + ops::ScaleKernel, + ops::ScaleKernel, + ops::ScaleKernel, + ops::ScaleKernel, + ops::ScaleKernel, + ops::ScaleKernel, + ops::ScaleKernel); + ``` + +2. Paddle-Lite的kernel注册写法,为每一个Kernel都声明了输入输出信息,但由于每个数据类型的kernel都是不同的,也会造成写法上的冗余,如下代码可以看到,除了data type,其他的信息也基本是冗余的 + + ``` + #ifdef LITE_BUILD_EXTRA + using scale_int32_f = + paddle::lite::kernels::arm::ScaleCompute; + REGISTER_LITE_KERNEL(scale, kARM, kFloat, kNCHW, scale_int32_f, int32) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))}) + .Finalize(); + + using scale_int64_f = + paddle::lite::kernels::arm::ScaleCompute; + REGISTER_LITE_KERNEL(scale, kARM, kFloat, kNCHW, scale_int64_f, int64) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))}) + .Finalize(); + #endif // LITE_BUILD_EXTRA + + #ifdef ENABLE_ARM_FP16 + using scale_float16 = + paddle::lite::kernels::arm::ScaleCompute; + REGISTER_LITE_KERNEL(scale, kARM, kFP16, kNCHW, scale_float16, def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFP16))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFP16))}) + .Finalize(); + + #endif // ENABLE_ARM_FP16 + + using scale_float = + paddle::lite::kernels::arm::ScaleCompute; + REGISTER_LITE_KERNEL(scale, kARM, kFloat, kNCHW, scale_float, def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFloat))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFloat))}) + .Finalize(); + + using scale_int32 = + paddle::lite::kernels::arm::ScaleCompute; + REGISTER_LITE_KERNEL(scale, kARM, kInt32, kNCHW, scale_int32, def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))}) + .Finalize(); + + using scale_int64 = + paddle::lite::kernels::arm::ScaleCompute; + REGISTER_LITE_KERNEL(scale, kARM, kInt64, kNCHW, scale_int64, def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))}) + .Finalize(); + ``` + +因此,本次设计,不希望继续保持目前这种冗余的写法,希望kernel注册方法足够简洁,同时还能够灵活地满足Kernel输入输出信息配置的需求。 + +对于这个问题,关键点在于kernel需要指定自己的device,layout和dtype作为它自己的key信息,而大部分kernel输入输出Tensor的device,layout和dtype和kernel自身是一致的,对于这类kernel,我们可以按照kernel的信息自动生成填充每个输入输出的信息,不需要通过BindInput,BindOutput声明;我们只需要针对与kernel信息不一致的输入输出去配置特殊信息即可。 + +新实现的kernel注册形式如下: + +``` +PT_REGISTER_KERNEL("sign", CPU, NCHW, pt::Sign, float, double) {} + +PT_REGISTER_KERNEL("mean", CPU, NCHW, pt::Mean, float, double) {} + +PT_REGISTER_KERNEL("scale", CPU, NCHW, pt::Scale, float, double, bfloat16, + uint8_t, int8_t, int16_t, int, int64_t) {} + +PT_REGISTER_KERNEL("scale.host", CPU, NCHW, pt::ScaleHost, float, double, bfloat16, + uint8_t, int8_t, int16_t, int, int64_t) { + kernel->InputAt(1).SetBackend(pt::Backend::kCPU); +} +``` + +说明如下: + +- 去除了之前注册方法中大量的冗余信息,可以一行代码完成8个数据类型的scale kernel注册,同时根据kernel信息默认记录每个输入输出的信息 +- 对于有`ScaleTensor`这种动态attr输入的kernel,可以在函数体重配置具体参数的Backend,Layout和Dtype信息;没有此类需求的,函数体为空即可 + +此外,在`PT_REGISTER_KERNEL`宏内,通过模板推导,对Kernel函数的函数形式了归一化处理。 + +输入参数列表各异的kernel统一被归一化为如下形式,从而能够以统一的函数指针存储到下文中的Kernel数据结构中: + +``` +using KernelFn = void (*)(KernelContext* ctx); +``` + +通过在Kernel函数外包裹`PT_KERNEL`进行自动推导 + +``` +#define PT_KERNEL(...) \ + ::pt::KernelImpl::Compute +``` + +此外,目前仅实现了基本的模板适配,后续我们会根据需求添加,以让在整体机制更加灵活,适用范围更广。 + +#### 2.3.4.4 Kernel管理 + +> 本节要点: +> 1. 介绍目前Kernel管理组件的设计 + +对于新形式Kernel的管理,目前设计类图如下: + +![kernel-design.png](./images/kernel-design.png) + +说明如下: + +- `KernelFactory`作为管理Kernel的全局单例数据结构,和fluid的OpKernelMap类似,两级map,第一层根据name找到Kernel集合,第二层根据KernelKey找到具体的Kernel +- `KernelKey`和原先的OpKernelType类似,但将palce和library_type字段合二为一称之为Backend,因为原先的LibraryType是一个有局限的枚举类,原本就和place是强相关的,拆分反而增加了理解成本 +- `Kernel`相比原先的OpKernel持有了更多信息,除了执行时的Function,还持有了具体参数的信息,即`KernelArgsDef`,对于Tensor类输入输出,保存了Tensor类型信息、Device,数据类型、数据布局,对于Attribute类输入输出,保存了类型信息 + + +### 2.3.5 Kernel自动化编译及依赖分析 + +> 本节要点: +> 1. 介绍kernel的自动化编译设计 +> 2. 介绍kernel的自动化依赖分析设计 + +原OpKernel迁移至phi之后,在编译上需要创建新的编译target,目前phi也设计了相应的自动化编译方式,使大家在迁移之后,尽可能不需要关注编译相关的内容。 + +#### 2.3.5.1 Kernel自动化编译 + +目前按照相应的规范迁移kernel之后,重新执行cmake,cmake会自动根据新增kernel的文件名,创建相应的编译对象,相关的逻辑在`paddle/phi/kernels/CMakeLists.txt` + +``` +set(COMMON_KERNEL_DEPS dense_tensor sparse_coo_tensor sparse_csr_tensor kernel_context kernel_factory arg_map_context convert_utils lod_utils) +set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} eigen_function blas math_function) +# remove this dep after removing fluid deps on tensor creation +set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} phi_api_utils) +set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} infermeta) + +# auto build kernel targets by cmake +register_kernels(EXCLUDES math_kernel DEPS ${COMMON_KERNEL_DEPS}) + +set(MATH_KERNEL_DEPS ${COMMON_KERNEL_DEPS} cast_kernel copy_kernel phi_transpose_cpu) +if(WITH_GPU OR WITH_ROCM) + set(MATH_KERNEL_DEPS ${MATH_KERNEL_DEPS} phi_transpose_gpu) +endif() +kernel_library(math_kernel DEPS ${MATH_KERNEL_DEPS}) +``` + +1. 首先,定义kernel的公共依赖集合`COMMON_KERNEL_DEPS`,有较多kernel依赖的组件均可以放置到该集合中 +2. 通过函数`register_kernels`,自动解析kernels目录下的`***_kernel.h`文件,自动创建对应的target +3. 如果某个kernel有自己独特的依赖,可以将其标记在`register_kernels`的EXCLUDES集合中,跳过对其的自动生成,后面再使用`kernel_library`函数,生成对应的kernel target,`kernel_library`也是根据文件名自动生成编译target的 + +具体`register_kernels`和`kernel_library`如果扫描文件并生成编译对象,可以参考`camke/phi.cmake`中的函数实现,此处不展开介绍了 + +#### 2.3.5.2 Kernel依赖自动化分析 + +phi kernel整体改为了函数式,本意就是让kernel之间可以更加方便地复用,但是复用kernel会引入kernel之间的编译依赖关系,比如A Kernel调用了B Kernel,那么在编译上,A Kernel需要DEPS B Kernel,这样的编译依赖声明对于开发者来讲同样是非常繁琐的,因此我们也设计了对应的自动化解析方式,具体如下: + +在编译A Kernel时,我们会分析A Kernel相关的`.h`和`.cc/cu`文件中include声明,如果A Kernel include了 B Kernel的头文件声明,我们会自动为A Kernel添加B Kernel target的依赖,例如: + +dot_kernel.h有`#include "paddle/phi/kernels/empty_kernel.h"`,那么在编译时,dot_kernel会自动依赖empty_kernel,这一过程也是在`register_kernels`和`kernel_library`函数中实现的,可以参考`camke/phi.cmake`中的函数实现。 + +因此,开发时如果需要进行Kernel复用,正确include相应头文件即可。 + +> 注意:这里只有kernel间的复用是会自动解析的,如果某个kernel依赖了某个function或者functor,仍然是需要手动声明依赖的,phi的设计鼓励kernel之间的复用,因为kernel本身也成为function了,因此像之前那种调用function的方式长期来讲是基本可以被淘汰掉的,只需要尽可能将function实现为kernel即可 + +### 2.3.6 InferMeta(Shape)抽象整合 + +原先fluid Op的InferShape和OpKernel一样,存在重复开发的问题,因为不同Op的InferShape函数无法复用,因此即使不同Op的InferShape逻辑一样或者类似,也都是重写一遍,本次phi的重构也需要解决此问题。 + +我们将InferShape同样改写为函数式,支持不同的Op可以调用同一个InferShape函数,提升易用性,降低维护成本。 + +> FAQ: +> 1. 为什么要叫InferMeta,而不是继续叫InferShape? +> 答:InferMeta的Meta来源于DenseTensor中的meta成员,在phi中,一个op有两大组件,InferMeta和Kernel。这里InferMeta覆盖了InferShape的功能,但又不限于InferShape,除了对dims和lod的推断,InferMeta中也会承担dtype和layout的推断,这一点和原先是不一样的。 + +#### 2.3.6.1 InferMeta相关设计 + +首先InferMeta也为函数式,几个示例如下: + +``` +void UnchangedInferMeta(const MetaTensor& x, MetaTensor* out) { + out->share_meta(x); +} + +void CastInferMeta(const MetaTensor& x, DataType out_dtype, MetaTensor* out) { + out->set_dims(x.dims()); + out->set_dtype(out_dtype); + out->set_layout(x.layout()); +} + +void CreateLikeInferMeta(const MetaTensor& x, + DataType dtype, + DataLayout layout, + MetaTensor* out) { + out->set_dims(x.dims()); + out->set_dtype(dtype == DataType::UNDEFINED ? x.dtype() : dtype); + out->set_layout(layout == DataLayout::UNDEFINED ? x.layout() : layout); +} + +void ConcatInferMeta(const std::vector& x, + const Scalar& axis_scalar, + MetaTensor* out, + MetaConfig config = MetaConfig()); +``` + +特征介绍如下: + +1. 函数命名为`[FunctionDesc|OpName]InferMeta` +2. 函数形式与Kernel类似,函数参数依次为MetaTensor输入,Attribute,MetaTensor输出,返回值为空,原则上InferMeta函数与其对应Kernel函数的参数列表是一一对应的,差别仅为Tensor参数类型,InferMeta函数的Tensor参数为MetaTensor,Kernel函数的Tensor参数为DenseTensor,SparseTensor等 +3. 对于一些需要区分编译期与执行期的InferMeta函数,在末尾添加MetaConfig参数,config中有is_runtime的flag成员,之所以用结构体,是为了便于后续扩展其他flag成员。 + +这里使用MetaTensor是为了屏蔽多种Tensor类型,以及兼容原先fluid的VarDesc及Variable,一个op对应一个InferMeta函数即可,如果不对类型进行屏蔽,本身InferMeta函数就会因为输入类型不同而重复开发多份。 + +其中MetaTensor的基础设计如下: + +``` +class MetaTensor { + public: + explicit MetaTensor(TensorBase* tensor) : tensor_(tensor) {} + + MetaTensor() = default; + MetaTensor(const MetaTensor&) = default; + MetaTensor(MetaTensor&&) = default; + MetaTensor& operator=(const MetaTensor&) = delete; + MetaTensor& operator=(MetaTensor&&) = delete; + + virtual ~MetaTensor() = default; + + virtual int64_t numel() const; + virtual DDim dims() const; + virtual DataType dtype() const; + virtual DataLayout layout() const; + virtual void set_dims(const DDim& dims); + virtual void set_dtype(DataType dtype); + virtual void set_layout(DataLayout layout); + virtual void share_lod(const MetaTensor& meta_tensor); + + private: + const LoD& lod() const; + TensorBase* tensor_; +}; +``` + +基类的MetaTensor中有一个TensorBase的指针成员,因此在phi中可以兼容DenseTensor,SelectedRows,SparseTensor等多种类型。 + +#### 2.3.6.2 InferMeta注册管理 + +为了支持InferMeta函数的统一调用,InferMeta函数也进行了统一的注册管理。 + +首先也需要类似前述Kernel形式归一化的`PT_KERTNEL`工具宏,命名为`PT_INFER_META`,并实现类似KernelContext的InferMetaContext(实现不展开了,仅放置部分片段,详见`phi/core/infermeta_utils.h`) + +``` +class InferMetaContext { + public: + InferMetaContext() = default; + ... +}; + +#define PT_INFER_META(...) \ + ::phi::InferMetaFnImpl::Call + +template +struct InferMetaFnImpl; + +template +struct InferMetaFnImpl { + static void Call(InferMetaContext* ctx) { + InferMetaFnCallHelper>::template Call<0, 0, 0>(ctx); + } + + private: + template + struct InferMetaFnCallHelper; + + ... +}; +``` + +然后设计对应的单例类用来存储MetaFn + +``` +class MetaFnFactory { + public: + static MetaFnFactory& Instance(); + + bool Contains(const std::string& kernel_name_prefix) const { + return meta_fn_map_.count(kernel_name_prefix) > 0; + } + + void Insert(std::string kernel_name_prefix, InferMetaFn infer_meta_fn) { + PADDLE_ENFORCE_NE( + Contains(kernel_name_prefix), + true, + phi::errors::AlreadyExists( + "`%s`'s Series Kernel's InferMetaFn has been registered.", + kernel_name_prefix)); + meta_fn_map_.insert( + {std::move(kernel_name_prefix), std::move(infer_meta_fn)}); + } + + const InferMetaFn& Get(const std::string& kernel_name_prefix) const { + auto it = meta_fn_map_.find(kernel_name_prefix); + PADDLE_ENFORCE_NE( + it, + meta_fn_map_.end(), + phi::errors::NotFound( + "`%s`'s Series Kernel's InferMetaFn is not registered.", + kernel_name_prefix)); + return it->second; + } + + private: + MetaFnFactory() = default; + + /** + * [ Why use kernel name prefix? ] + * + * one op -> a matrix of kernels + * + * such as, scale op, it may correspond to the following kernels: + * + * - scale, scale_sr, scale_dnnl + * - scale_raw, scale_raw_sr, scale_raw_dnnl + * + * All the kernels in each row correspond to the same infershape function, + * the number of kernel arguments in the same row is the same, and only + * the tensor types in the arguments are different. + */ + paddle::flat_hash_map meta_fn_map_; + + DISABLE_COPY_AND_ASSIGN(MetaFnFactory); +}; +``` + +封装对应的注册宏,用于InferMeta的注册,注册写法示例如下: + +``` +PT_REGISTER_INFER_META_FN(sign, phi::UnchangedInferMeta); +``` + +对于InferMeta的注册,一般不需要开发者手写,我们通过yaml中api name和InferMeta的映射关系,自动生成对应的注册条目。 + +#### 2.3.6.3 InferMeta兼容fluid InferShape + +在fluid中,继承MetaTensor实现CompatMetaTensor,重写对应的成员方法,以使InferMeta函数兼容VarDesc和Variable的输入,以dims为例,CompatMetaTensor的dims实现为: + +``` +class CompatMetaTensor : public phi::MetaTensor { + public: + CompatMetaTensor(InferShapeVarPtr var, bool is_runtime) + : var_(std::move(var)), is_runtime_(is_runtime) {} + + CompatMetaTensor() = default; + CompatMetaTensor(const CompatMetaTensor&) = default; + CompatMetaTensor(CompatMetaTensor&&) = default; + CompatMetaTensor& operator=(const CompatMetaTensor&) = delete; + CompatMetaTensor& operator=(CompatMetaTensor&&) = delete; + + ... + + DDim dims() const override { + if (is_runtime_) { + auto* var = BOOST_GET_CONST(Variable*, var_); + if (var->IsType()) { + return var->Get().dims(); + } else if (var->IsType()) { + return var->Get().dims(); + } else { + PADDLE_THROW(platform::errors::Unimplemented( + "Currently, only can get dims from DenseTensor or SelectedRows.")); + } + } else { + auto* var = BOOST_GET_CONST(VarDesc*, var_); + return make_ddim(var->GetShape()); + } + } + ... +}; +``` + +然后,为了将函数式的InferMeta嫁接回fluid的Op体系上,需要将函数式的InferMeta归一化为functor形式。 + +通过前面介绍的PT_INFER_META宏归一化函数形式,然后将`PT_INFER_META(***InferMeta)`包装到一个functor中,functor中先将InferShapeContext转换为InferMetaContext,再调用相应InferMeta函数,通过一个宏统一管理代码 + +``` +#define DELCARE_INFER_SHAPE_FUNCTOR(op_type, functor_name, fn) \ + struct functor_name : public paddle::framework::InferShapeBase { \ + void operator()( \ + paddle::framework::InferShapeContext* ctx) const override { \ + auto infer_meta_context = \ + paddle::framework::BuildInferMetaContext(ctx, #op_type); \ + fn(&infer_meta_context); \ + } \ + } +``` + +这其中的关键函数是`BuildInferMetaContext`,这个函数会从InferShapeContext中,将InferMeta函数需要的参数取出,统一放到InferMetaContext中并返回,InferMeta需要的参数列表通过ArgumentMapping函数获取(详细在2.4 动静态图执行兼容适配中介绍)。 + +然后将该functor在Op注册时维护到相应OpInfo中即可,同时删除原先Op的InferShape实现,示例如下 + +``` +// 原先实现 +class SignOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "sign"); + OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "sign"); + + ctx->SetOutputDim("Out", ctx->GetInputDim("X")); + ctx->ShareLoD("X", /*->*/ "Out"); + } +}; + +namespace ops = paddle::operators; + +REGISTER_OPERATOR(sign, ops::SignOp, ops::SignOpMaker, + ops::SignGradMaker, + ops::SignGradMaker); + +// 升级后实现 +class SignOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; +}; + +DELCARE_INFER_SHAPE_FUNCTOR( + sign, SignInferShapeFunctor, PT_INFER_META(phi::UnchangedInferMetaNew)); +REGISTER_OPERATOR(sign, ops::SignOp, ops::SignOpMaker, + ops::SignGradMaker, + ops::SignGradMaker, + SignInferShapeFunctor); + +``` + +至此,实现原Op的InferShape函数迁移至phi InferMeta之后,可以重新注册回fluid中被调用,从而实现InferShape的函数化复用与全局统一。 + +## 2.4 动静态图执行兼容适配 + +> 本节要点: +> 1. 新形式Kernel如何在现有静态图和动态图体系中调用,难点在于解决多参数Op到少参数Kernel的匹配问题 + +### 2.4.1 ArgumentMapping体系设计 + +由于新形式Kernel参数列表与Python API对齐,和原先的OpMaker中注册的参数列表存在差异,导致新形式Kernel在原先fluid体系中调用时会很难匹配 + +例如conv2d op,它的OpMaker中注册了4个Input,1个Output,26个Attribute,而conv2d的Python API一共只有8个参数(不算name,3个Tensor输入,5个Attribute输入) + +运行时,调用新Kernel之前,需要将Kernel需要的参数从OpMaker注册的参数中选出来,再传给新Kernel使用。 + +对于一些原先就编写规范的算子,它的OpMaker参数和Python api参数本就是对应的,这种标准的情况,不存在需要选参数的需求,对于这部分算子,根据OpProto中输入输出属性的注册顺序,跳过标记为Extra和Quant的成员,可以解决一部分Op和Kernel的参数匹配问题;然而对于一些不太规范,或者说是fluid时代遗留的算子,比如像conv,就需要这样的映射函数,且这个映射函数根据op不同,可能存在非常复杂的判断逻辑,因此现阶段没有办法可以自动化处理。 + +为此,目前设计了ArgumentMapping函数映射的体系,在phi/ops/compat目录下,实现相应的映射函数并注册,然后在phi kernel执行适配时,会调用对应的ArgumentMapping函数,得到phi kernel需要的参数,例如scale op的映射函数如下: + +``` +/** + * Note [ Why does the ArgumentMapping function need to be so complicated? ] + * + * In order to meet the requirements of infrt, the function used to match Op + * and Kernel parameters, need to be placed in phi as a compatible component, + * and does not depend on fluid. + * + * Because infrt not only needs to dynamically call this argument mapping + * function at runtime, but also needs to statically declare all possible + * results of the function before running without any information. + * + * The infrt declare like: + * + * def PDKEL_Reshape_to_CPU : Pat< + * (PD_ReshapeOp $x, $shape_tensor, $shape_attr), // OpMaker arguements + * (PDKEL_ReshapeKernelAttr $x, fn($shape_attr)>; // Kernel arguments + * def PDKEL_Reshape_to_CPU : Pat< + * (PD_ReshapeOp $x, $shape_tensor, $shape_attr), + * (PDKEL_ReshapeKernelAttr $x, fn($shape_tensor)>; + * + * Therefore, we need to write out each result of the argument mapping function, + * like `KernelSignature("full", {}, {"ShapeTensor", "value"}, {"Out"})`, it + * cannot contains variable, only can contains const char* string. + * + * Infrt will parse all results before running for the generation of the above + * static declare, which leads to some functions being written in a long way, + * and the complicated ones may have hundreds of lines, which has certain side + * effects on the programming experience. + */ +KernelSignature ScaleOpArgumentMapping(const ArgumentMappingContext& ctx) { + if (ctx.IsDenseTensorInput("X")) { + if (ctx.HasInput("ScaleTensor")) { + return KernelSignature( + "scale", {"X"}, {"ScaleTensor", "bias", "bias_after_scale"}, {"Out"}); + } else { + return KernelSignature( + "scale", {"X"}, {"scale", "bias", "bias_after_scale"}, {"Out"}); + } + } else if (ctx.IsSelectedRowsInput("X")) { + if (ctx.HasInput("ScaleTensor")) { + return KernelSignature("scale_sr", + {"X"}, + {"ScaleTensor", "bias", "bias_after_scale"}, + {"Out"}); + } else { + return KernelSignature( + "scale_sr", {"X"}, {"scale", "bias", "bias_after_scale"}, {"Out"}); + } + } else { + return KernelSignature("unregistered", {}, {}, {}); + } +} +``` + +其中的ArgumentMappingContext基本接口设计如下: + +``` +// TODO(chenweihang): Add more methods if needed in future +class ArgumentMappingContext { + public: + virtual ~ArgumentMappingContext() = default; + + virtual bool HasInput(const std::string& name) const = 0; + virtual bool HasOutput(const std::string& name) const = 0; + virtual bool HasAttr(const std::string& name) const = 0; + + // now we can't use Attribute here, it will cause phi relay on + // boost::variant and BlockDesc + virtual paddle::any Attr(const std::string& name) const = 0; + + virtual size_t InputSize(const std::string& name) const = 0; + virtual size_t OutputSize(const std::string& name) const = 0; + + virtual bool IsDenseTensorInput(const std::string& name) const = 0; + virtual bool IsSelectedRowsInput(const std::string& name) const = 0; + + virtual bool IsDenseTensorOutput(const std::string& name) const = 0; + virtual bool IsSelectedRowsOutput(const std::string& name) const = 0; +}; +``` + +无论ScaleOpArgumentMapping是在fluid中使用,还是在infrt中使用,只要能够构造出特定框架的ArgumentMappingContext,即可获得对应的参数映射关系。 + +**1)对fluid的适配** + +在fluid中,该函数需要同时在静态图和动态图中使用,比较直接的思路是,直接通过ExecutionContext构造ArgumentMappingContext,然后在op执行时调用,例如 + +``` +// TODO(chenweihang): split impl based OpProto or Dygraph if needed +class ExecutionArgumentMappingContext : public phi::ArgumentMappingContext { + public: + ExecutionArgumentMappingContext(const ExecutionContext& ctx) : ctx_(ctx) {} + + bool HasInput(const std::string& name) const override { + return ctx_.HasInput(name); + } + + bool HasOutput(const std::string& name) const override { + return ctx_.HasOutput(name); + } + + bool HasAttr(const std::string& name) const override { + return ctx_.HasAttr(name); + } + + size_t InputSize(const std::string& name) const override { + return ctx_.InputSize(name); + } + + size_t OutputSize(const std::string& name) const override { + return ctx_.OutputSize(name); + } + + bool IsDenseTensorInput(const std::string& name) const override { + return ctx_.InputVar(name)->IsType() || + ctx_.InputVar(name)->IsType(); + } + + bool IsSelectedRowsInput(const std::string& name) const override { + return ctx_.InputVar(name)->IsType(); + } + + private: + const ExecutionContext& ctx_; +}; +``` + +**2)对infrt的适配** + +若在infrt中,infrt只有训练存储的推理program,也就是只有Proto这一层的信息,那么可以通过Proto信息去构造对应的Context使用,**proto中的信息目前在支持参数匹配上是完备的**,例如 + +``` +class ProtoArgumentMappingContext : public phi::ArgumentMappingContext { + public: + ProtoArgumentMappingContext(proto::OpProto* op, proto::BlockDesc* block) : op_(op), block_(block) {} + + bool HasInput(const std::string& name) const override { + // simple search + for (int i = 0; i < proto_->input_size(); ++i) { + auto& in = proto_->inputs()[i]; + if (in.name() == name) { + return true; + } + } + return false; + } + + bool HasOutput(const std::string& name) const override { + // simple search + for (int i = 0; i < proto_->output_size(); ++i) { + auto& out = proto_->outputs()[i]; + if (out.name() == name) { + return true; + } + } + return false; + } + + bool HasAttr(const std::string& name) const override { + // simple search + for (int i = 0; i < proto_->attrs_size(); ++i) { + auto& attr = proto_->attrs()[i]; + if (attr.name() == name) { + return true; + } + } + return false; + } + + size_t InputSize(const std::string& name) const override { + return proto_->input_size(); + } + + size_t OutputSize(const std::string& name) const override { + return proto_->output_size(); + } + + bool IsDenseTensorInput(const std::string& name) const override { + for (int i = 0; i < block_.vars_size(); ++i) { + auto& var = block_.vars()[i]; + if (var.name() == name) { + if (var.type() == proto::VarType::LOD_TENSOR) { + return true; + } + } + } + // TODO(chenweihang): throw error when cannot found + return false; + } + + bool IsSelectedRowsInput(const std::string& name) const override { + for (int i = 0; i < block_.vars_size(); ++i) { + auto& var = block_.vars()[i]; + if (var.name() == name) { + if (var.type() == proto::VarType::SELECTED_ROWS) { + return true; + } + } + } + // TODO(chenweihang): throw error when cannot found + return false; + } + + private: + proto::OpProto op_*; + proto::BlockDesc block_*; +}; +``` + +### 2.4.2 phi Kernel兼容调度执行 + +目前phi kernel可以兼容地在老Executor,ParallelExecutor,动态图的Tracer,Engine,推理的Predictor,以及新执行器InterpreterCore等在执行体系中被调度执行。 + +具体地,在动静态图调用OpKernel之前,判断对于当前计算,比如`scale`是否有新形式的Kernel已经注册,如果已经注册了,则调用新形式的Kernel去执行,如果没找到合适的Kernel,仍然执行之前已有的OpKernel。 + +``` + if (phi::KernelFactory::Instance().HasCompatiblePhiKernel(type_)) { + if (pt_kernel_signature_ == nullptr || pt_kernel_ == nullptr) { + pt_kernel_signature_.reset(new KernelSignature( + std::move(GetExpectedPhiKernelArgs(exe_ctx)))); + VLOG(6) << *pt_kernel_signature_.get(); + + kernel_type_.reset( + new OpKernelType(std::move(InnerGetExpectedKernelType(exe_ctx)))); + dev_ctx = pool.Get(kernel_type_->place_); + + pt_kernel_name = pt_kernel_signature_->name; + pt_kernel_key = TransOpKernelTypeToPhiKernelKey(*kernel_type_.get()); + pt_kernel_.reset( + new phi::Kernel(phi::KernelFactory::Instance().SelectKernel( + pt_kernel_name, pt_kernel_key))); + + if (pt_kernel_->IsValid()) { + VLOG(6) << "Static mode ChoosePhiKernel - kernel name: " + << pt_kernel_name << " | kernel key: " << pt_kernel_key + << " | kernel: " << *pt_kernel_; + } else { + VLOG(6) << "Static mode ChoosePhiKernel - kernel `" << pt_kernel_name + << "` not found."; + } + } + if (pt_kernel_->IsValid()) { + run_phi_kernel_ = true; + } else { + auto& all_op_kernels = AllOpKernels(); + auto kernels_iter = all_op_kernels.find(type_); + if (kernels_iter == all_op_kernels.end() || + kernels_iter->second.find(*kernel_type_.get()) == + kernels_iter->second.end() +#ifdef PADDLE_WITH_XPU + || + paddle::platform::is_xpu_place(kernel_type_->place_) && // NOLINT + !paddle::platform::is_xpu_support_op( + type_, *kernel_type_.get()) // NOLINT + || paddle::platform::is_in_xpu_black_list(type_) +#endif + ) { + auto pt_cpu_kernel_key = + FallBackToCpu(*kernel_type_.get(), pt_kernel_key, *this); + pt_kernel_.reset( + new phi::Kernel(phi::KernelFactory::Instance().SelectKernel( + pt_kernel_name, pt_cpu_kernel_key))); + + dev_ctx = pool.Get(platform::CPUPlace()); + if (pt_kernel_->IsValid()) { + VLOG(6) << "Static mode PrepareImpl - kernel name: " << pt_kernel_name + << " | kernel key: " << pt_cpu_kernel_key + << " | kernel: " << *pt_kernel_; + run_phi_kernel_ = true; + } + } + } + } + if (!run_phi_kernel_) { + if (kernel_type_.get() == nullptr || kernel_func_.get() == nullptr) { + ChooseKernel(exe_ctx); + dev_ctx = pool.Get(kernel_type_->place_); + } + } + +... + + if (run_phi_kernel_) { + phi::KernelContext pt_kernel_context; + // Do data transform before building KernelContext + // TODO(zhiqiu): support TransferInplaceVarsBack + PreparePhiData(exec_scope, *pt_kernel_, *pt_kernel_signature_, + runtime_ctx); + BuildPhiKernelContext(*runtime_ctx, dev_ctx, &pt_kernel_context); + (*pt_kernel_)(&pt_kernel_context); + } else { + (*kernel_func_)( + ExecutionContext(*this, exec_scope, *dev_ctx, *runtime_ctx)); + } +``` + +对于phi kernel的执行,有两个关键函数 + +**GetExpectedPhiKernelArgs** + +- 在调用phi kernel时,要完成多属性到少属性的匹配,这里就需要调用前述的ArgumentMapping函数,从而得到phi kernel的参数列表,GetExpectedPhiKernelArgs实现如下: + +``` +KernelSignature OperatorWithKernel::GetExpectedPhiKernelArgs( + const ExecutionContext& ctx) const { + ExecutionArgumentMappingContext arg_mapping_ctx(ctx); + return phi::OpUtilsMap::Instance().GetArgumentMappingFn(Type())( + arg_mapping_ctx); +} +``` + +**BuildPhiKernelContext** + +- 要调用phi kernel,需要准备phi kernel需要的Context,PhiKernelContext和原先的RuntimeContext及ExecutionContext不同之处在于,PhiKernelContext中是以SmallVector存储输入输出及属性,访问效率上要比原先的map高一些 +- PhiKernelContext中不存储输入输出及属性的name,要求这几项顺次存储,和kernel的参数列表顺序一致 + +Phi KernelContext的基本设计如下: + +``` +/** + * Note: KernelContext doesn't manage the life of DeviceContext and Tensor + * + * Note: KernelContext does not couple the concept of framework, + * its constructor can only take the members it needs as parameters, + * not Scope, RuntimeContext, etc. as parameters + */ +class KernelContext { + public: + KernelContext() = default; + explicit KernelContext(DeviceContext* dev_ctx) : dev_ctx_(dev_ctx) {} + + void SetDeviceContext(DeviceContext* dev_ctx) { dev_ctx_ = dev_ctx; } + + template + const CtxType& GetDeviceContext() const { + return static_cast(*dev_ctx_); + } + + void EmplaceBackInput(const TensorBase* input); + + void EmplaceBackInputWithoutSetRange(const TensorBase* input); + + void EmplaceBackInputs(paddle::small_vector inputs); + + void EmplaceBackOutput(TensorBase* output); + + void EmplaceBackOutputWithoutSetRange(TensorBase* output); + + void EmplaceBackOutputs(paddle::small_vector outputs); + + void EmplaceBackAttr(paddle::any attr); + + const std::pair& InputRangeAt(size_t idx) const; + + const std::pair& OutputRangeAt(size_t idx) const; + + void AssignInputRange(std::pair&& range, size_t idx); + + void AssignOutputRange(std::pair&& range, size_t idx); + + template + const TensorType& InputAt(size_t idx) const { + return static_cast(*(inputs_.at(idx))); + } + + template + paddle::optional OptionalInputAt(size_t idx) const { + const auto& input = inputs_.at(idx); + return input ? paddle::optional{static_cast< + const TensorType&>(*input)} + : paddle::optional{paddle::none}; + } + + template + std::vector MoveInputsBetween(size_t start, size_t end) { + std::vector v; + for (size_t i = start; i < end; ++i) { + auto t = static_cast(inputs_.at(i)); + v.emplace_back(*t); + inputs_[i] = nullptr; + } + return v; + } + + template + TensorType* MutableOutputAt(size_t idx) { + return static_cast(outputs_.at(idx)); + } + + template + std::vector MutableOutputBetween(size_t start, size_t end) { + std::vector v; + for (size_t i = start; i < end; ++i) { + v.emplace_back(static_cast(outputs_.at(i))); + } + return v; + } + + template + AttrType AttrAt(size_t idx) const { + try { + return paddle::any_cast(attrs_.at(idx)); + } catch (paddle::bad_any_cast&) { + PADDLE_THROW(phi::errors::InvalidArgument( + "Attribute cast error in Op Kernel Context.")); + } + } + + size_t InputsSize() const { return inputs_.size(); } + size_t OutputsSize() const { return outputs_.size(); } + size_t AttrsSize() const { return attrs_.size(); } + + private: + DeviceContext* dev_ctx_; + + paddle::small_vector inputs_; + paddle::small_vector outputs_; + paddle::small_vector attrs_; + + paddle::small_vector> input_range_; + paddle::small_vector> output_range_; +}; +``` + +## 2.5 产品思考及后续规划 + +目前,phi算子库仍然处在Kernel体系的建设阶段,Kernel尚未完全迁移,且仍然存在诸多完善点,但将来phi算子库会更好地将“算子”的概念纳入进来,这还需要比较长的时间和比较大的人力投入。最后,从“产品”的角度介绍一下phi后续对于算子开发范式的规划,也能够让开发者更容易理解 “为什么要做算子库重构?” 这件事。 + +### 2.5.2 原算子开发范式 + +我们应该如何描述“框架算子”这个概念? + +万变不离其宗: + +- 我们如何描述一个人?1. 他叫什么,长什么样;2. 他的工作、兴趣、爱好、特长、品质等 +- 我们如何描述一个物品?1. 它叫什么,长什么样;2. 它的用途和功能是什么 + - 比如一个杯子:1. 它叫水杯,长这样;2. 它用来盛水的 + +简单说,我们描述一个对象,可以采用两段式结构: + +1. 它的名字,样子或者说形态 +2. 它的功能,特征,及细节 + +算子同样可以按照这个原则去类比: + +1. 这个算子叫什么,有哪些参数,返回值是什么(即Op) +2. 这个算子在不同场景,不同设备中,怎么执行,怎么计算(即Kernel) + +如果我们**能分得清楚1和2的边界,并守住这个边界**,我们设计就能够趋于简练。 + +这是什么意思?就是说如果我们一定要用两段式来介绍一个对象,那么哪部分应该在第一段?哪部分应该在第二段?得有个逻辑清晰的认知。例如,我们用两段式介绍一个人: + +- 方式1:1. 他叫张三;2. 他在百度工作,他喜欢唱歌、爬山、骑行,他待人真诚,认真负责 +- 方式2:1. 他叫张三,他喜欢唱歌;2. 他在百度工作,他喜欢爬山、骑行,他待人真诚,认真负责 + +哪种分段方式更好一些呢?答案是显然的,方式2的两段中有同样形式的内容,逻辑不清。 + +为什么用这种方式来类比?因为我们的算子开发面临的场景我觉得是一样的,市面上现有的框架,对于算子的定义,都围绕着**“1. 算子描述,2. 算子执行”**的两段式进行设计。 + +顺着这个思路,我从“语文、逻辑和信息认知”的角度介绍一下我对fluid算子开发现状的理解,如果把现在的算子体系当做一篇以**算子**为题目的“小学作文”来看的话,拿高分有点困难。 + +**(1)"生僻词"比较多** + +fluid的Op开发概念对于新人来讲,可能是一种看“文言文”的感觉,似懂非懂。 + +如果我要描述一个“运算”,我需要讲清楚它叫什么,输入输出有哪些,这就够了,例如一个乘法运算,`叫multiply,输入x,y,得到out`,在这一点上,Python API是足够简练的。 + +那么现在我们的内部算子要怎么描述呢?要实现以下类和函数,可以参考 [mul_op](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/fluid/operators/mul_op.cc) : + +``` +# 前向op +OperatorWithKernel +- InferShape +- GetExpectedKernelType +- GetKernelTypeForVar +OpMaker +# 反向op +GradOp +GradOpMaker +# 类型推导 +VarTypeInference +# 显存优化 +InplaceOpInference +NoNeedBufferVarsInference +# Op注册 +REGISTER_OPERATOR +# Op版本管理 +REGISTER_OP_VERSION +``` + +直观看说实话会有点困惑,新人可能会有什么疑问呢? + +- Operator可以理解,是算子,OpMaker也可以理解,是告诉这个算子怎么生成,但为什么要两个呢?OpMaker都已经告诉你要怎么生成这个Op了,为什么还需要再写一个Operator?Maker不应该把Operator make出来吗? +- 除了这俩,剩下的都看不懂。。。这些都什么意思?什么时候用?我新开发这个算子哪个需要写,哪个不需要写? +- 算了,我短时间内搞不懂,算子也着急,找一个类似的算子,copy一份,它写了什么我照着挪过来,能跑对就行。。。(这可能是大部分新人开发算子的心态) + +**(2)重复的“修饰”比较多** + +其实每个算子真正不同的信息,就那么几个词,剩余的东西都是模板一样的存在,人脑处理信息是有成本的,去区分差异是需要思考的,从产品的角度来讲,直接把不同的地方告诉用户,让用户只关注这些差异是最高效的。 + +以 [MulOpMaker](https://github.com/PaddlePaddle/Paddle/blob/c6f49f0b9f189e043b458348d7fd1468e2645621/paddle/fluid/operators/mul_op.cc#L72) 和 [DotOpMaker](https://github.com/PaddlePaddle/Paddle/blob/c6f49f0b9f189e043b458348d7fd1468e2645621/paddle/fluid/operators/dot_op.cc#L35) 的实现为例,我们可以发现以下几点: + +1. 除了Op名字,输入、输出和参数命名,两段结构极其类似?为什么我们不能把这几个空抠出来让开发者直接填空? +2. 输入、输出和参数后面的大段描述属于**重复建设**,并且现阶段没有用处,因为Python端已经写过一遍了,并且写得更规范,更清楚,C++端这里的参数注释没有人把关,质量参差不齐。 + +再看Operator的GetExpectedKernelType([mul](https://github.com/PaddlePaddle/Paddle/blob/c6f49f0b9f189e043b458348d7fd1468e2645621/paddle/fluid/operators/mul_op.cc#L41)和[dot](https://github.com/PaddlePaddle/Paddle/blob/c6f49f0b9f189e043b458348d7fd1468e2645621/paddle/fluid/operators/dot_op.cc#L28)):,也一样,都是根据x选择kernel,那为什么还要让开发者写其他的内容呢?直接做个填空,填x是不是就行了。 + +我们开发Op的时候,这些组件多少都存在这样的问题,这增加了大家工作量和理解成本。 + +**(3)相同的“段落”写了好多遍** + +这里主要指OpKernel的开发,我们现在的OpKernel之间可复用性较差,比如已经有了mul和add的Kernel,我们现在要新增一个fc算子(由mul和add)组成,我们得去mul和add的kernel中把代码拷贝一份过来使用,而不能直接调用mul和add的kernel。 + +这是我们建设phi初期要解决的问题,并且我从周围新人的口中已经听到过多次这样的反馈: + +- 我开发新算子,需要一个broadcast操作,我得去另一个算子里copy过来,还得先调通,copy的时候可能没copy全,或者应用场景稍有不同,这都需要额外的时间 +- 实现gumbol-softmax算子,因为softmax是其中的子运算,我得先把softmax的kernel实现copy过来 + +**(4)“描述”本身有二义性分段** + +说回开始的两段式结构,”1.算子描述;2.算子执行“,分这两段是必要的,也是业界普遍的做法,我们不需要再分第三段了,但paddle目前存在第三段,算子描述分了两段进行,并且这两段还不一致,即PythonaAPI和Op。 + +API和Op都是对算子运行行为的概要描述,本质上只是同一段内容的不同展现形式,比如[python dot API](https://github.com/PaddlePaddle/Paddle/blob/c6f49f0b9f189e043b458348d7fd1468e2645621/python/paddle/tensor/linalg.py#L993)和[DotOpMaker](https://github.com/PaddlePaddle/Paddle/blob/c6f49f0b9f189e043b458348d7fd1468e2645621/paddle/fluid/operators/dot_op.cc#L35),就是告诉别人“它叫什么,参数都是什么”。 + + +咱们对同一个东西的描述,分两个地方写,还写得不一样,这是很令人费解的。就好像你介绍一个人,在学校你说他叫”张三“,在公司你说他叫”张三丰“,有相像之处,但又不是一个意思。 + +对于一个算子,它的输入、输出应该在各个场景下都是一致的,如果不一致,那本质上就不是一个算子。 + +比如,conv2d的api和op,[Python conv2d API](https://github.com/PaddlePaddle/Paddle/blob/c6f49f0b9f189e043b458348d7fd1468e2645621/python/paddle/nn/functional/conv.py#L416),很简单,8个输入参数;但是对应的[conv2d op](https://github.com/PaddlePaddle/Paddle/blob/c6f49f0b9f189e043b458348d7fd1468e2645621/paddle/fluid/operators/conv_op.cc#L259),有**32个**输入参数,让人摸不着头脑。 + +开发者也会很困惑,我开发op的时候,API和Op不是一个东西吗,我应该写得一样呢?还是不一样? + +推理之前为什么要做**算子增强推全**,就是op的参数太多了,但API的参数很少,这两者本来是介绍一个东西,却差别如此之大,所以需要发动全员,在op的某些参数上标记AsExtra,就声明这个参数可能是多余的。 + +当然我们演变到如此田地,有一定历史原因: + +1. Op输入输出参数规范限制差,留的口子太大,可以天马行空地写; +2. 2.0 API对外层Python API的形态做了大范围规整,但是Op层保持不变,是导致目前同一段描述差异变大的一个主要原因。 + +对于这个问题的解决,我们的方向是很明确的,就是**Op层描述向API层靠拢,因为API层的定义是经过2.0 API项目仔细设计过的**。 + +### 2.5.1 新算子开发范式:完形填空 + 拼积木 + +phi期望的Op开发方式:**“完形填空”式算子描述实现 + “堆积木”式算子执行实现** + +**Op实现:** + +需要写的内容如下: + +``` +# 配置文件 api.yaml +- api : add + args : (const Tensor& x, const Tensor& y) + output : Tensor + infer_meta : + func : ElementwiseInferMeta + param : [x, y, -1] + kernel : + func : add + param : [x, y, -1] +``` + +以填空为主要方式,名字,输入、输出、输出的增强推断,用什么Kernel。 + +原先需要写得大段重复代码,全部通过”代码自动生成“的手段去实现,开发者不用再关注。 + +主要思想:仅让开发者关注最小的差异化信息集合,填空指定信息。 + +这里Op配置时,要求和Python端参数命名等完全一致,做到上下层描述一致,不给开发者留空间在op层自由发挥,导致想加什么加什么的随意行为。如果需要给op加参数,API也要一起更新,这首先需要通过不兼容升级评审。 + +**Kernel实现:** + +``` +template +Fc(const Context& dev_ctx, const Tensor& x, const Tensor& w, const Tensor& b, Tensor* out) { + phi::add(phi::mul(x, w), b, out); +} + +PT_REGISTE_KERNEL("fc", Fc, ...) +``` + +mul和add操作的拼接,代码量很少,再加一个注册声明。 + +整个Op+Kernel的开发也就十几行代码,在去除所有冗余信息,仅保留差异化信息上,这种方式已经是没有什么精简空间了。 diff --git a/docs/design/phi/images/kernel-design.png b/docs/design/phi/images/kernel-design.png new file mode 100644 index 0000000000000000000000000000000000000000..c58f01dddd26d5c0375fade8c9b4acba5966bcd1 GIT binary patch literal 102777 zcmeFXcTiMaw>}7n3Pu!jAczR&&cR6MoO7G#oO5ne5OYKXR1_6+!h|`3iinsoVMfe~ zuV6q>)Y<)h_g3AiJ2QVy)vszS>C?T>K0B|qp66L-F_%s0+G#+iu&}VMG%ArF7S=8$ zEG*)0R0mKp^YDoo;1cfQQ}AK?Zwz`L7B=9NnnP7-Rl3@U2AS2LI{PQ(9bB{U_CTcP{Ubm&N@j^MMp2^bmI&3{UD2Djd${7vt*uqwdt|J??n#%0i1M}yIann8=9ZWYpnAaM~;tCfT{`GQ^(Eg6QV zL(n|9QiDO6v3`{XiV3h>I+zIIvoR>4Wxxtl`T$R6Vv$UEw;#@y^hc{!$MKX7AXedLScgb00raFP-ui8S|FFv;ZzczBiCTa zU}R)GG+0jGGDYFj6QyzjkqZ&Au`ry?j1__UY#-jv(0N=u zI+F>32aNsS!4seml#2lOYQ-Y27ww`4ES%7SDB(uXkrm0&Dz#{gMlb$%dU$dG;{xR1US zVuL)h)Nh2+&|V$ei%=s8c9xSwVTib(4$0|aAqdYBQ_YxP;`(=r@(lA zjz{Fhq3Ku^K}Fyo^+pRPVDe}%I24B^V6oX&5>M;Hb5J&@S}ntamBo`4FeXvuG@@aA zg^7zp3ZxniN1$^#gGz(Ni4q!c5JONx#`~>4fdEE=@k85zm%-pVJQS~E>hye*jo>oS zHE0;tK+xi0R2skh?C(Z5;0O^bfSp@1VV(DGLSSQ%BOT1 zP$;^W;HP;QD2>j`CtDaO1WxI7^L-pT1%a@;*nTXNK($GrL6|~D;Rz5VD%a_RtNmy- z9VexNeZN2njzYc5;KA<7%72wL}`FUie$WIaxOg^&;A(oTKdJ_!+7h-LEe-Ned znYG|KHrXzqli*G_L*VtQ{bUkJgs@rkI1+_Ow2A0`3qu6;c|2A(UTf1S$S#GQ17~?r zVi<}HItg_ol%j1Iyw58lIX!q4&nzSm@kk!a<-*f(R*f7@A&XRexl&4m2qY{&gCfP5 zi6$FL=wb8J5HuMyu3(axKC{E*)JYvIxR^tc(oqVNTI`j09B`fr$B?kN7Lq}4q#9T# ztqV%0)xlfjAXac1?Lzwa-{$qRB`~v z)7g;{v<*hWGcja>TIHh&1Z)QIjARWQkA+znAkc2KcjYN>LAtas0EtXkuKBm(uHZkeC z0NETKEBvjiIvWkzm)zZX=IEvuR*b6_(;~8}ulJRb~^y6ageY2u1RgOq$6c z!`uC0D-kbo+qr74hi+9%u>_)6gtB5V3WL#0Aq%kri3Ka=dZco)1<8@Pa73IBhv$Q( zS5a`pfY3z3lQn1Ax^=o95R_ryHw0w`*CL8 ziaBl?hDyL&R8}NPqwpfYh`k&oTNWBQR1`!~6h;#r!WC#uRG$|_IniQs$re2dqQG`mO;hK&ml30{Gg z=z+NMiDF7Ua6TB+A5S4 zp{H_%CWg$YW;=wSJCn_X@+(!SKu}@!Lgj=2ljLLon}&Kr$42O+QFSQ2--0m-VFG|s zfcK}6{NO6m+iYkl%c14-RCH+2Y*(;#6a`#urAw(OjGHU42QsV+T_5Qk?nN40ReTR;7}%> z=GM4f6f&AD5YwDiEy+bExp*Xxk^nQSb(DZgb?Q?htuD~5}; zOHB%0=zv39QkRC!^q>VSoyl#-S*1FgG&EE;M20fh=pqx|z~%VoGu5#3W127 z3aXC<;mIjzK0q=?4VodL;+R5{-3q}|VO)$w#neHGVl~{0BmkU)mx^3CuEwhuXlyVa zM$C6;Wq6bjst}WfkbuNwms%hQtHlg+YPAkWfUc1{)B+OO<8l*W43gOE((tgr{W4ty zy3Y?osjLtT#AcSOAUX_-rPt%EcnlHdl{!c?g%wOnC1fl73=$Y01S%vaJ1SWUP0k|uzE2S%$kf7*whXm$^w_F zaa_GeY_l2l7KP4a(5bm*2^j;n9v|8fpAya2h+!HOT*3a=yU55qzmw*$^LTbI%>{v& zKt(#l<>zu80z!(djv zk03KhwRpK3VMXvXT7(H88M8d(D3BtUiL7yeG2jslwT))g`JEzw7pZI%-09I!Nm@i8 z6#w}?D^rOOB0`2haWl9&lo#hka1kUnM~YCJoM=DKK{v56cCVJD()wizmEVFk!mtn+ zO$?KARU(g#;0&3fM(mbZ9R@m`!SM$re7c9OQ4&Hus3dloPUBGuy;h~tuGhHWPKyAq zG*TpJfGO}CwSq{odO-JB2%mxup)@O$3bTMK1`ntUxc)bL`G24?Xr}?9HtLu>i_mP> zXenwV9)z6`b^2da1{D4`l~IDeS+)?FF{v41iPx$|0|cUwV5J_oTPhWJ)o6!?0VDA7 zD8E4~R68MBFAT5K8rXJ{NonOOIS_|Ui8t9;7>v)y#RWY+Ji-e_W6d-&6ypbw3y+lX zojff}PQlCV9+(14#8@;Mg}}$vAY^ouUO?nI@KUXf8ZeObc7jKQ69y$F4nZaI^CWr= zg+nsy1Qc7qM3EvKYNSi=;iB9qxsGV|5IlY&NvSbw83Y=W#d9+MMZ{W~+~!3RHA=Wa z;+622atY06k|J<4E>ujydZjq3MI|GF1x4CDd^FKQf@!RD4HYR@kyT6-m+4TE>^7`l z!G-&r7%WmA)POTbX4x$SyqOiCQqXRzQECj}H3qE}U;`G(Z&c~%L_M0UV_Ac68VnQA zK#U|ZK>$?+t#lk(B;g8JHY1W<= zoES2KA*L8GECiKLQD6ivh|YpT`D82r>GU2RN#Yh6Xl5k9oP2|UjB)W1cp*hWv=J;2 zjM+_rDg$^KRDyHqNgNpj1V1ug&9j998i}Dp;z(o|0YbI9xdtbX95Q7V6HnC(EM^?W z!?eh3NG}yZ6$A}Fzr*1Gkc>=_L%}D`;FJf1GLtfdNa6GV14S^IAbthPr1SA%X0jSW zGwJwFB1LUtK!QFpP2{w|5dgcRG$FD}Fj>qp69^GxHB>I;8c2AV&0rxCXfhs+hM?*V z2$IP!FzNg#4I%{oMOrb7DDwc6jg~3c5GN)O5QhA;AE`IV5eNYXufP!)bOHg(vs#cO z7*K2kHj=>*@N%IzInknEaRd++TPRS;dmIhGN5EC z;7TR1W;7$_*u4WP`b?8|-2OB)d^4Mh(1B+*Yp{X{hf^YS( zSwg6oOmR9T7BNfB*P-lc2VNi$5+DS(24#m+5g>kZXm}CI3_8(bVS1{G6vV0|Q~}TH z5^@}Hg5T=GQSl}mATJIO9*IJn0%srwtS$zyW3mS7bh)Ki7h3Ce2bgptTHz)H;0U5d zAcR^>GBgK}05R5UK!<#!oFnlW{2VITuM*3ZSUnYr_9InJj7G^eNw8EGfh^>@eFl6$ zO!W|abf?1VQ2POt!{ZqePEbLD1{_$OK_@JscLt>(u~z!MkxfI5Z#+K^O)N zKx0ruVUsmRJ%q?pdxK~(p6U-Oj9j8e#lToeBqW2z)_bfXFfgTuOk=o}I2YLg!Ps3S z1XYG|(P1PZ0}Iehz=)T~{bo6wNu;BlQk5?x2+5otqDyV0(G4O7ooIz)?ItXZMP=$` zN~#@V#mm$LFG??i3v2{|9w(qFEDRr=slyB8NfgQ0Y!*;3aZE%KoNWvRM156GeveIhCzbrf!{j`7EJb#AfIvYz;=uhXwgyZKWP24K53kU^1gyMHbqL zav8%yCqUsMy9DeI#sMWt%p5ep1!N6FDnf|;IPevYGbvE^ko>|#2^k^A$_)4_ubX>d` zPICakA?T(0h-QP{0>Bmv0nySl4w0CEW`ukwhl`eJQJ4TyBj*Oha5UG0A~J$(ss@8$ zy7hQ!2uQF=Sd_(1XZ!qWh%w}Dna%(gkHrgV1f)nLRH`*inkUbNgy_iLme=#d?0JlrXEVr&o;REbq86jBG6yV4mt*cud1ic{EJ zE*OGuBhdX69+QMtg*+$5L6(3D0R3s__}kMH$49W}iM$~Bl6`*2qxS)w`52E zDKQj=Frjg}v>YUttpyP|qy!^KSf7;Xl_MZxhZisqrV5Vt5GT}0b=Chg#iiLZTFbXRt4TqX9~P6jmIdy7gKkgyNz&Y@E;%qE&J< zlZ-bz%?>6CBb5Lr?UYg^QjEr~3`#=+1QSJ(OWkaXnJl!j#WFM84|k&|5Ub1UmN}t3 zs*H@)kl08koGM{k=^I2tZP1(psUAiIfx9N- zQaap=mmpluf7u1j<##&iP&Ab(S@xH?wt!d1##`l00^N z@`KM4zn98fr-r6Q46U0KvG^r1cLfr5(w4S@7y-)0#YB-ETiTJlFT4@^F8@=2I{M}G ztf}3op?lBp=@(W#S1W(vpTVZx$cv(7e7yN1Dw8}RCDAR<~Hnn)ke@0rlt(ZS`eK1*VsPl+tsKV{6P^}k2$8~;6b_23{a zr+K;N`j^&K!%ghM!{(*LMg6Xf=I-DO8nd%Fs)vw!^A zI%wUUIcLXj!!766?E4&LxOHak6mxyT8?c5-c3I7co{;fi1rBG_ZfoWht-X6Fe640U zd-^&57}w66yl;i2yStDpsJZtF``+A;`=Z^cs@>0X=%>2KwoQ8wz6z_XQ?|t#* z-?P6JFN=F5*gx&$(DqB(QboplYfd!vq^_?-zG~PVy9y)gWO#h?K>lv&H1U|iB-hU6 zc_Y`{7`cA@kGNqczx4T8xj)|kuX^#^d-cb2%9`O{@9%PuyETj{i=UjDH0)Zzyv2E~ z&n`46mJVmntZhYXh&;0vY(?MkH`nZBJL)r@J*bE-?R8-Kxt03cs_J8N-zM{4rFN?P zs^e@=zd!VYEzr47-rQf_J)}ycBO}`^i~9Y3a7B`6B0z#rxjOEvY{D zbjrOl=qB!F0^|HB+TEtAwlwBuOYJQFf^YA4RWb^PojvmD31`x*pF{Tz6bUkKu1d-& zn&74W&iJ*tv-8)R@3w|VhhIvPt&Kk-qw}c8gwUbNimKAjuaMDwMy>rkdQkkawe=MX zH*oII%9jG8@E`g*Gw`8T%DK0*SVdC#B-52)1s?94-yf1|cE29W->DdBcywxMiF}-- zvg_xOgEPAIS=4?}UgY9c)m5~w$&BVFFvGXBWcHDXqf+nT#p(+sud{3R^j@mHvg=A` z;esucg~LiSe~4wr6lH@Z|9p(DuXr?&CHLf>T0Uj(GRl&>?~{3}Uez2ob;degWnG(b za@vgz6Ne{taDN^7RF275)$sJB?2&I^DreJ#KYL0(t{HKuKq@=pJ^rd{TAG&AfAqzU z?t&frPSbCknSbt(>Cd%j*a=_9_h|E~>ifdc8z_)X?6V76o^5Gzg!Xx6_^p{e8~MP* zr+>iA*o3;Y`18InR#5_jcg{Kbe5G5>nf=r!KbO;ouw!}r2-Sk-KOMTzk1koWXIHlW zYQnVGl$g#ZzxK@~xrSZMinA0X%sDtG;hT@l2 zNce9#YT5DM4~~}lp9y6LgFAapUVS}yllG6>@%njYT=gdO4B9U4X3^^H^1?IxQgia| z>8G2}-lu0>f=(LNi1|wj#L1SJX?u?iTHU4a@UQ>ZS$yo;^nFi(Fk(w>&g*-Bmb~t) zqkq0}eYMyR}+AmEPDS_`FwpT3cJ!Sulvs(@Nvjm+BurUljYdxhO5`h*2t9lAk)VcF<(oPk%BsfxUnF z$<*JckLNBvEKbc0$2mXzNaFA3v&woREQFb_<6XO9`}TGGlfUhX;L1fcC zESA3Jy}WZ-BmUxPkANGK)x6G#@2i1cTAVX(qc$&VO)Q8M&W=And;=SH^XuyRNLvHG z!&dvTzbJqAxQWwX+Uv=h&Klvv{2ON%R#ze!_s-?f4om1K8hd09Wj{MJd;7d@M9Ta< zM>QW#Y<7OE6t#ZJ8CYt)H{^6hH0thg%T&Ui2ZmAem#)cOG5W)j&&6b!}&S`QVqzC4lU$NAmw;mot*eh<4xb&V<{9o|{>JLCAh#}{t3#ilHM z%jVU0?D~2}szvp6A~;p&-n!(>2yFuJwH?O(EbbaAaevBSComi1AAQh#-Ij6s@3Z|! zzgErI^UCn%RqmGut%pDOs>Xa9gdv|j-w8MKa@^c)hV|ADw!F(-Ry8I4yi)JU7@D!a zZA<%asZU3U&h4o#88rn$Slr==;-!wU{ewq@?SJlV#TU_qKN>Liz`p7OM@|-B>-3AB zP&D9pLkYcV6ysbadrHZ(Yu~?JDdX>SLOs(@@Bh7h_TlE5y(&TSu_HS^<=v*2kNR{0 z+WnuWxow59yDvU4kFWmBXgZs-v=n#>%3#i6TKUzGr#RVemggs!)Q7;Y2J-48`d$6( z1O1MDF0tiYSYLXpFKlq_=NDAMf)sJdiJn8$RW^2JbHm~hr!V~2F8%Shz7_smeBjpR zi@Uzs8!rC(F?9BYiaE0%zWOK`n{2C$?3pAU^w3*3=y}cSMK785ZkHBE7g1xWpP zNB!0nJ0H&JdBX8&3GiK<(l=E}t(^;2)+OR2gWbAALKwmIvp}4i^$>Kh_UXl8M|0{c zwyoGM`yWVR$o~KsKSSsuCpiO@{eS-OyNO>HO_#`SB;u+ZYsTC#Os=q{SLVx z(S64hF^e!f*5`YBCH!+U11H_@R}I40(|RV;U)31;j+P6J_x2>u@6>zMt1Ed|H;nfU zwLUpvyw-IfiV6z1W&ZtBvy&75^o^p|dT95p2+ICEa^d*5II#YwLnr&ekLi|^N$6>; zqPAb3%f9^n@%;4l*&t}GWfzPncwFguoz}U}DDHU>QU2~*aC_dB{nJiwqvj6TKS=(4 zqQ5acncQdW^ikhGK3Eg@`ewJo_s^5z=d;@jF1M6o!ebZZUCB!sso)|uvOwg(Cn6?6ptZepKVTqKtxOU96Q}4!g-&*@*7cXhrC+7}080(E? zmt6puQ`5Qw2N>k4*?$0^Ht)Tat;2<=z%r{v5(ZL=4ine|KDsaI;k z07nt4HS{3^i3p40qbDH_mMRLYnA#epr_?lwDNIDN3lSz)m@CEtnAa(sKRg z_TZd#gY!CKt`1bBLVa$wkqb!e( zG!49bzYX%X@7U>8y9GO6HNLnwqBBBqAu{<)$>Y<$&%N&a`22k2UqVs(zJiewcHE?y zyJij^?=^b{tT->1opf%m9#(%i%o;55L7F&4S(iDJe{Y*K7L2)Z{khCZfBrfBIlt^h zd`Z_M_xDX1xq0r*-I=V|6x`9)J3BemZ?CUMui0}F+o9`SUF@IBZ%ZMIs+y}$hX&+@kw_n)WVS%);KzrML{Jp1nUpp)zm&@sP&b4m!;0XB#s z7njZc9XY30M8oJQ$EfdXwu>4UeHjwTSUwOfKEG=i8`3fVK|;T?vwxog z+fAJK>1J*cC2JJ4V`Sc?CB46n?^fSoOGI|W*R?RslAe-&v2%KCH4SEuWy$e^}LDH1RHtJ z&&?_bn=|(4wdhL1u`t!en~fbSlIL|>5f4ev`Z@4^_?yV{*x7CVhlf-P>lZh0)0%6Y zgW8rH<)3tFUe6WplMl}gTcM8`aoqZp|MKCw_y}#}k-T9n{v={~7oMtrF=&2@_U6WY zZDEN^R6XkOS;d)Ad^Kwnd{|@R{mBu5lB7jM$if7-_pE8sv1>);+Sio%= z_^OzMFx8I$@PU%ap2RAM<< zyNj@}+mhsUDAfJVr5SoDI>s&9EA|9Ko8Bv>U%SGsAf8?YMw9>dc7oj8+AnGn&M0zJ&qNUGk@d0a30vZu?iv+VXExPv9dt*lYt?=r#+!dRV zfTB!J;%2Af>>baiKaJ`4B6@sjOjOJ582gZ@^RY7!k*`2#6J>Yc9P1G!M@4>)+I8gG zmF|<v!c(5`uC$FoP(C0yg;^$?YQDj#TDQh?#3$fFUYYh(5=j*YyUiX&V+lZzyfqh6}3R^lAlv;MS-r}X{PdLy_mq20Xl0T-gXhL4Sq z9z2SiyePJ$D@&S%hIML;u1LH;EMjJq?&<>1{G-tWAMBsDcV6dg^}z*x3T?gTVsK8F z=U;>El~R?K&aPX0GYrO)a?+JoFUK&dn$u?`2^K62f7BuKuYbG|9w)5(y8A5m{7={X zLl8;+{TJuw?rR;ge*5CstF4=7M8!5fKVrz6QJAnII<#eRExrkt7RU8^*b&HjrvY|8 zu^w~z&it%>>;GBUJwEd9W`Jov?A=+i4uf^n&)btsXl-CC3-y&6_Ls>`e)2k(J!7*MA+#%D_+=SDG)rcF$>A4%whPok|F zN>{h0up8Zj&X0_cM;1qwLQ_V&vYp85akWdGBl2gCmX#9Q&#hXPzA;AgXlY~#x}y6a z4EJHz+__|M!j_QW9500^1?_MA2HGRGK_fUL?U!3x(89l;U5~s}f%-5d84$6*wtw~LqmkA=) zt;G*>UTv?(p+eXG2Q2PxufV%J@* znlbvvJXz*`$i`vpmDdj;XEbyy?UnMbV+Qr$Lf6#QU7;Fg>l|F z`XnZ&E{vne9ejRvhkS8j9ckj^<^jbMI?RET#x-{^MIGU8S`o}#G$?XLOj7Em+u8Av zD8F&pBWFHd5g$c2#GL4GH=!r@r@D27GUHNo_ON_ZV$@}=bJ_gsZK-H{){?AFE^_1Y zYq1f*K}&bG$)h`M?4z0yRXQ+g9`BT@_1KvFNjOAtV$%Hbhm-ppd^{bjWJApI)EPCA z(Ge4`b&a@C2EyFsbG_9wS0>MymUGQGM5K+dIzAG-C>mMa-8=SEv zA&w_+pPR+(gi-y*uN`;#+2~p3(QHWHGcnD?4LHH2EzN84I>B;x+==?9X$ zug3L%cWF&WK+INVg(S8U{UE1(UBdxAx&Pdi{bg|ekVHFS*2&iktow-3oCk*_7fVa7 zir>6?e`n{OUY%xbPAsm>6VEK0+dMHPRfHSf9|a|pCtMAi6H~i)dDAQL<0t!j9HkVc zPi~jZ&ZqtU`SHL9p}C6mCT#u3rVJKw^Q+FuoT`H@KgOl2lVYTa{tH;3qWs>H--&ZQ zk-5^)l+!*!`Dr}lNZ2*;n@F7b@`}N~VO^uRJzA93)z4iKnF)<|U_IA%uIT$F%8@;) z$1~=FwCSB!yvV4H;3-D45Bl0azqnb$PPnyFQ>y76IXH53^R7ryQSpSX*>~=@9YlBd zNEZ;_W<|ij^e;!Xb($rWv{2DIf&VZSQ;lRq~4Q|VBXz$0(OCP*t^Rlj6dTAUt z_WKy|zkWys z_K}j}eXYii1Ge&rgy_kND_=wOq+AB*$&H?` zTH_MymaHEGi>{ryth6(wY+H{E&cTJ@2RxUw>Z|3%cJs?#?1omBS7t=6j%xUAs$En3 zpf7qz-JGZk+8L2gTDlLM+2d|fG5%rq>WhuXHfIf}UKVq${rv_-4{3bMcw+mF(xkOr zV}~u9xO1kf%bSQEOM1pl$X&=KHAfbX&R9P>BPQ)-*Yro{dT0DPJmCgdtlYubeUe>u z#g?HDE>5n(Pd(`*75H)a<7%dUQk|(EAN=-?UX+-(U;#QE=a^%-Pi=o``j80+Z*4tUKW_KsGqu1ahebycv+C!4yf`v??fV^_4`p9}^ar`9eg3SI zdlCn>b7Y@In^*M6O2v`e$slJ3R-Z~>uS7N;UGRHAt*Tv*-De*Skaf-X;EuGWw4L41 z!(4f>N5102?Q8B0cjtBtVe-aqU>TpiNV!q7B+Pc|=cwq0$61iu&+@l(4m6O)M$~Sa zXYSE#ZIKhM3~cVM*tEQPOQ#vCx{|YJ*y9*0--0BJKN~0&DZ5?a{1VG^y$LT@UtQZ4O_WY zJq>%xv+?5FuS5WJdNxJ^P|M-IUCWNGi^c8iQQ8ia=Ec@|YF>5Q5g3{tQmBV{N*-RX zXFX_$+c66O;Tt`{U0L>oO<+M{fh;fqw?Q?Im4%-6|3Ce|R_)98_5T1Gnk?MK{nGFd zzBo2z#lxqpU!ko?sQdVU)^TN1FRZ-9bCS3ufyfoSU?$(^<7uUi`OZybaj~aLX1-jE)nxUE|N61{d_R@?(xUU6}Mu;B^F!0M5y7b`*RL%+mKzq`FO-~x{=OhSiAFf?ieb@c@*)$01(H_y>ZMvHqvtj!_ z+}nF~QIDj2F2>ui*8Slw?a$ATV?*M~vccoU?dvv8&RiYR0?*C*H56!0d%BN)_uyxV zaQCVsw&&BoyuJCi;2|K8$2z($6amE|C6TZ|cp2At#PV3&{Lj;oG1D(D>@HEYHC>K} z0Uhm1A$(`4?aj49Sn!*kwjQ)!j!&_gkJd@d0h? z^}?@-A-!Zt?EdfrL(^}v`f}>VPCKb{`F%GkT|G7*}9hrR2**j zn2ld3%Ljd0<}mH&KeLv~&KZ`kp{LGCVzXR0tG|=+$p;Zo2yC_pjGKzkggc zl-3jEo(g6}$KE*Y`MA17c9>ik{o*#zJdyENJ2+~9o`4)1RTPp~fl78S5iBFGzFpbp zn(a#ysPRAv13DY=#>RM# zcS@A7NBY(mS0-jl)C3^)ZkyMBV_j0(f`imT*!Jo%dd;-HW2SQFOo(WEeB3tDk)786 z{Dg`t|2#dt-TvYNP^UT{-5GH?3GQmF!<9js?*4sqeLa0q0(?&9#SUstV}5+ZyK9B; z5z@@Yx;dK;+ngztp$cyKoQJ+)scdNE5n6 z7^BX0?mc1~D(9ttQCQ^Ps0PXYskQnjiMp!$wePt`|{QzRGfzfN{1MeuCKpRgjKMv;Kvbv0=+P820n9`2^pE(P5{_;C> zQW~R#M<)NQH=*TheY%Z9NJ?u-Z@`!g+OMOxp#9Hx21JgyXr2xxsiFnPZ9XMx!xkH4btsCc0k8N8O&#O0O{@D$StGziV{SnL+F%Wta z_^*y5*hh|gCI-Jz+rNprzh=y|J}B7gaf!)^G3gs-Gkb%!Z9ui&sX6CoSs#*RPlWsQ1g<-KE#E%KP?< zAKUFbM135xIXO*A*h~^&cGEN_OJ4;%(4NiX&RlmzEU2tn$aJ>0Qp=>Iu>5I zVjC@Qgd@A>F5vfeXZ1ahRH^+@s<}F%Y0jvq2z}9={nN%D?LY8*X(pKNR*=syZt75d zjJsf7N%S*J`^{OSq9dc5*~RY8UkR0wvQAegro@b_>X~>rYXP=+tQtCM{rA1|j+&!U zuan9zKO3-Y{>1py&DpFNQWaCGW1APpUTgo$h_XE^A$Ds?KhQC9<;42Pv0c)PLIe_eJ65+Bw5uOC;HzN!10PPj#zwo-ET!=|TByf~>(egja9e-3nHoAbq$?V;L} z9V-#>eFoGYj%+x|{yA#U05$-%QUKx>c)F(ai3tZ&k4ahmT6q6=qP_A@k9;N04eU-0 zi#hS~%J_(YB`1<4|C5SKK7j{3OO4bEjuY&1iO`I{q~UWQ>Fl zV-WcpH;+BDuXD1ZmXf`|(c?)H!HH3O9FqpTS%`MwYKzZX)%M2b16dW(e-@xK zHZ#F0z3+8yB+-;Q{^I&eD-UH&T&Q@g|3CkjPZ!^iS>1pukx|8e^Sf4gr>dj zJo`}y%Z~xme)DDLP3g;j){jA5E(7VU_2|sRy;#uYOKEKd_WriLlqDiBQ`MroV3*9A8ycJZjAJGyV|+ zUr;YTi*1bPe&)r)QB_`C(S7~Tqi0i&nQ(|W6{!15)eeQ0$zkd+cSq1Dsc(N}s zZ2H@)eGn0;Wf$YNg4Etfoc80Ck8|&j%o!H`GN$qM$8C-?<%J=_RS_wH&>Tf2*JjZZ#a(#5$EurZ^;A; zY<%DjvXEft20Z3Jk^T*L2M-@V6l9*;0V6`Ey@~l5TgQ@Q?r6vb(0?kpk61I=&^5ZH z6SzI2>ByCaSF8u>Zrib7Hm3lE;QO{UtOPZ?0+R3~tdQ`I14zQ}!cxk=AqV`n{}JC9F*I#qw|Hn*xqU}@&^k2@2XPogjecQCFZ*xoC`F4vzhG6E8r@rZ!47O}+L<^I+ zCAd+PX#^6TVd&L`gI~`+bWgON0tuK^)~BZc3TO0M0^k%C zWX84um>j+ZYp=@^?w=})sTf~#fB&?R{bA~r>MN_R{Ldt{XFj%T{G_SJERTmJ)AQ}> zsM9w=wMA1;*m2j}b^W?aw0HHO@m}USgki{w3-j{;<=InqHu#-hy=LrK;Prb#sQIb6 zPhVo3>9@C#?oI8~d&E$X`CW6$^RfC&;A`6{kbqka;4I<#J;8PAA5-nY&HMjEd<0;x z8bsxGmrO&u&0*hP*zMNKaqkHy^drDY#evw?ucAj0`Bv0B#~o7RZ@{!d{9|5pEE#2e zlIr3co_|xO#9FPev{xi|_hWUL{%B97*Z99*bO`6@7 z3ZPPc?~!8a{SsNbZ<}*5m?O2!;HnV)+Oj$H_C$wa;Ejm%w|3*P=vvnu%d;~lzkPhb z0sQ5)jnXm;0O{?VO_OG4ZUVVMb#*=sFo7hHQe6Y~Xixf}?1+cU0DxP16lj!h4yi6z zUtKeHyZ6(hn$td8bdMqPf*+P)0R$s2E}YY@V|82WSJ954v)fu=zJI2k%=oc?69(ie zmonF*SRkjor}DwN%-z;v; zjT^>!_=UOPpzUcTFYUt`b?FKa7DI{gJr!NQKFFR{J8l0|Y@nraa~k5(iosL9xBmH^ zamo6+x&WZNUhkirJbZT-kMO7EW%qlw8%0DV%() zMFxUchw`|oJfnIvdqKCsdnyUsuFsMvEAl~Nb7TjQR1vw#X8k%=?|fO=75Aq7!}yj+ z*{CVak>9XKx>V1tpWJxq5d%=0k}1d8n)%J2pC7pY0I<9K^b+yDb;|~#*$>lm{to&s z`TpTvLBGqb>n4|XdKhDB8kb%kwm5F$sh$T{U#AZ8KW&Q-&)A>$iWGWVnG9a!RU&fi zle|3+zTL4jEH0Tkxvg<-*{mV1?6AJu2~o9QUtissEj==J#<^W}190eUcBBcs50D4m z2iR37y;_uTm4C`@D40dtFz#rl#sO$+$qiiHQZoTqldyPM@u2VtD<}{kegW1?FKYXC zhchB;#wqviPK`&kd;i8w3{n9XPX=!s(B0*$F^>9OHHdu03&W3whlkt$GOQe&U9A5v zk`&4mmrjaWFnTU0vN?v5`Vc>NbW|?x9p-#om$qf)eJ$YZftNL6ZHr>^37Zj}eka%~@67sGIwtaa zO#hg^AZN20BrSJe`@h(G>!_;QwqI0W(b6T-<)T4E0g>*IP6Je0kPr#!k`j?F5mZ1_ z8bm<4Q$UnPV$n!TBXF+yyzhI)-rqUj-eZq({ybwi#`EY}YsNkAdC%+m)x=FJIPH5a znG1SdAp&MLQqXS{t;*8>gxZ!Ml}N|vD$*k+a?MwsbsN7Xsf^~}Qcl<=DWC@IH0bd+Fo1(WlZ>GPw#1dF- z72GfT9yM>6#PF-!0(6Zk(TuUh_nUcVf^M#GOUaiF%BUByYO(TfT7m>NjkTeQKs1&o z;8zE&iI^qBy)xjV3>!T=-oIh9eRiRN%FO|T&tR!Tu9(I6t-j(_plO~_LETt9j>|Tv zH;!v%&-p^DYfLD8d#_OpW=)OK`BCmw3i&S@RA_<>wnEmV29^Sl{rF)}Id{oj#%i`i zaPmnGsE!}S*a6+3zTj}xQj1k9Xhr6mv}I@UWPgLfp2+x)}J zQbZ>*wK;&57$RKwOS9=g-eY4Ue?Sl5EgZ`>RTx{I;}jdSsb2&&a%Zj<6SQ$HMU z)V50yNSC}lfo6-BIOvThtW}>r^S&Y38c7a9FOJ3A-#*DA3#YO?AA|e7g-`?ss7zgJ zkU)4}h_Pyzq!;a3?vpu7faQUY0vJ~ZX_X}WQEFhk88ERtSlX5e!eTow_&Y*L7!tx| zgJos+u;}T#>F{LJ;cl4v&fgj#aIEvtN+bOr=3u2!hTLxCWgsOtL5Nj%5)X?Ur<~%$ z)OhP}X&=wOyrea4f%*m9C(G?U+gDpB#}1}WsET3#q&>UiAXcrTDC%;+f zv4MUpMG?$xz300bb5G&AdZn<;&GdWP-`?MRRX67P4mv!1dD1|B6s`JeR(_o%QQxB{ zG&>gOH)Suio4GuA&Z7SNpDvK0jP1rfvKB%g9No~9YltYUh{X|vRM&e&mMXZkVwhmg z{SCVsL@BDl_WY^%ij2lFHrxc{#6A8=7*6ifrdx1xTHUctGL3^$gV;>*4WWi*e{t_t zRf&T#=3kSpbn*H!2e<7d7%4iW9cH(2`EY83rABbAENkcd zJDi}_?T9PUO3ox$IQN6jNBf_&zhE$qpt5w)xoE`)m+u~Hw2pOX=;KSM_KsB(zrKh) zS`mjg_>#Zx(JjjVyrv1 zSARg8PyV7*R5V=u06)ip+djDeb8k9G9aJM!!byu87JSwU9uDo@n<963c22QU9vK=1Xfn*Pz4kSs9g+XKo>Q86$rz4*9u+=~>sysFZ_mus^B zyq0R@%?s(`r;LHML*$AaBP!&NMq_;CSi?m=G~ALs!FMm5pQ4E2cpkKhzvo~gw%Q!% zDZke^DddlmNicoO2{-N!tZ({4#|Es{C-Y6yM5%Jk5kd?vd9h zUCP;NsPcxpVQ`VGCYr&e3hEZ!y(%vzQz~XPW*E_V{652@Hg%Ay4EWW)& zlTYr>KCV$C&T;VcrWyA~u?ok^HH>%hKg%^eVAT&xzI0EUIz(Lkuz&N`Ea*5tC9|e{ zcPu*)VwLoKy@MeRkP>WTaz zyICevKQUI1c6+tJoZCS4N&9GgKoMfXGVJrN`c<>+RfQwr3&)m>uuk*g5(rgI47YLR zIo|3gQqc6>?MvZ+rqELi+P2zp|8yzB174)#+^tC-|MsddSu*Aa8k$rI`1o#g3J!d= zG5~-t;y0xl!_cq}??y>r!!|vT?f;+dK5#2ft26Po_FWlrm4AeQLR>L7P+#zG~T7PDn=WQWp`)5%#7{{KB&47V;jqKPfVpDLaNta^F4 zDAR5=iw$|`!k=j(*mB{UE@^0YebO4B4atU8*8DO7NR4a!&n9!L`aS=B5Cl zL`lFN2#?k-z=Z3s6OYC}p6}`ej68URfec3TpM#!a1G}4K-TNHajU#LSCqur}8&6BW z#X>p60d>4bv!*x-;dzmZgG`2vKlC=5PzdN?13s7tjU%pjHK$q}L9_+Ozqd$vkaVt< z&=?o_$&w?sQI6KzR(%2K4rrbdrCc*@Y08OLB~jcF0-E1j29; zU#s!|d+^=dLT9hln9#dE(+=yO?*{xvW^J^9H!>^ausM{YVt~l$fY!CHRN8ZM`kK;n zvIp-v|8v=d;QR#7BO~3PDm##^Ha}KwZ@!Ck6wAM=D|$IfB>#QY0IvXt0D^V^;T6yT z58!O5pSS-Wx?X){`~6)fNQ*53QeBAtz2SQfFdy3X$ZQbZd=5%ISx{j7%cFScGCB~< zU+1m7*+aryn63XzpLJy9q=y(I5R`x*UB(}wql8|v$QWTid-3l9Q)P9AO$J`UP$0h{ zx5Nh4G4X?LmO6+J5h*bxyJ9%&ek0YthF7&=7{=ist=&!t?z(OQn1g18=JkT`;79_} zqzjbSocixIAMdUh9v&`cBHeM^Wea7sbn%i9BAPC%Klj@IGeiYY8l7LL0v?X|+7IUI zPgFtn%?TRhI5-gvkkg*Ormq!-Qe`zMzDFrysY(cP2LuAKQimd#3;vb>q>IMc>B->x z14Ld8eL4?_kPW6t&j3Anx;+{BQ$XK z%&7Ir+Ptrg`^mCzSPuwUq-)HF5v~>RbiS<_b8jG{rY{gs`DFP_RJ!b@(EcNA2x|b& zY95w^An&ydDR0ZsQmZ&=U$5QQij~8q2Z>vvJTIs#gNePH0N0zOVpxIRoK z=Oo#Cn3mDgr_!*0S5#XRJqI+5aiFYL2N7BN?F%ju7qSJM?*f9;^5kf@kUpR`4)DkbSZ`vSm%W8Eej-xk+RiYTQve-uYxMyB zX5om9s2!oBagDF3`lb212D#VpnFaxwD5s!h*dBoy1y~=Z1wjfIE(OpHOwZE-k11O5 z_H#6<23h6tUmJY3)mHX08wU!Z0DV+BK=#K#ok)89c2Z)GB+xh8$wU__o1Snn%S9mgYAI z194Xvua+voYg%I*ywH?8@lWyK;>-XjcVSW~ot4l>37Nt*XxE!D0a*VW8DR~%7r8VX z5gpaR!B!T~f=0eLHyGeoTZEe#4n8a{p>=XFFu{Hc&-#NmB}54;5!2{%tTJy{PR?y~ zUdOTW;G9N?Sp3Fimy1Fnh#`XWo+h$3>=4^1VSgdm%;+hX!)^nQ7)DNJQ+`SAfH4hq zM-#T%8ZatFxTG1dZ}~c~RFcv;1~wf$y(9jDPF{+;l#EjN5|&RRubCHCY?Csq_deJ) zXoe}3^tVn3R$G#}*ptcu)?|F}x#wYV0b$6}df_k9Ndxe{(oL(M zt93GHdm0f>Oegq3DCn!}Vzgq^h3~Cp7W;6?dnrBP?w%$#)P`Fg5WLOby-DH{{--r_ z(C|4<)fLpqWgs)r7v-t#&SDSN^Ju6@ToUe(_x_%$uP0K>FGE8|Xn7OU{!$1q29hR% z8}RsOh({gmEOU^tNGE{d!G&Vf{4asQu%g_gH%kZ%91M_NrXTj|XPWcdm z>o0m8Vg3in2kg(i>Bp3CPWvGndk~Gj$~~9{toT&;5GKTwN=zcS8L5~|Y_y?c1_ehP z0G~wBR`AT-WP|~fYR}+^m%Rsb6q_Ua$JACfp0?3TMhQ&Cev4kZu>CtKTXzL)X7+$p z!o7@i;G#kD!1K?+0@?kchr8?&Gy4xE+5cz(9#e*K1)iZByhW0+2`?nW#9;8XILU59 zHic1Eapr~v6l;~H8q7ADnbLP7;>p?MKb&P`HNRn!NRSk;yK5_NB%Hfu0)1y}j#bx1Y;J*Q-jq_!=ALwwt ztWR#M-2{aCCzWB{+wN~#)6SgJ^4J|2vf9^=*Jk}XyW(DWI6{Y# z>&p;0_+snFP!pEka9^$A8E}MfYmvf^EPG%_LdcGk;!a`${=kmXZI*P%JoSDc)taas zWj)CxThNUNRMgFLqw!#n`E$k5LCnRS^9^$QKZC`QjoiULG15Hh3R7uTwTEjSD%ZoUZ ziwnQk*Q}jSuT$t?vUOgXU3Y7DRBezXSY;AXg_*ZM*lOCyX_!hC$kM@~+JK7^JWac) zcXhj9(FRBCYGT*ucxAw)YEGR5b> zdl*?|NR0Q?{{yjlx=T}|Gv@ogOCd5P=}l_%o(o2|5UE}-7Op0uqTp%QV#~a>>H})E z?#wLjaG18f(>z+Mvn^S76yU?FcSiX@!UF&N{RU?A5S|T5GAjGza0QMV^wUS$BUGwF z;Y0h%Qe#SCq%h5T+EVuCt@swdfAl&V+?IQRx#A@+sa`%Nfe(kBtw8(oFE}jRW>LWf zV|vQ?O93vheHwLpl-DQhl;mIGWsZm^D%7a$9Ti;I4casenJ@vBJaT zqO8zNai<{a!;)~ovisb>iZ63n3`LdB4Yvxi^f};w>f&2g?87on z&$@@R@8z^@W9IEmFZ-q#2pIQBvuZH|znoMsC3yAKevBW}E7gs?7uomQ14Wz;mATBm z2&8hxvWI#+hJ$+j?}NG>8K0M2sQxQ+r86(N$g-#ByI$!XwF=E(Op8PhsBE{cjE_>I zt^aVZl{|;*cQ>Ua3N!b!mn*LR0SA|6k`@#DRo;Zn;F8C)TJ+cKGA8Qwb+#9KIH~a5 zk}69cbnHB{x19N8o}*h3DdPM6BK<^p&Ur>LeUY4-JR6hs$uGwDHA=80a&G?_odY#An8&-W^YU8S99qlGds#>%! zXiNJYa9=_9B)IVn<<-R|F8Alvh7eE$rBb1Xa9Sg%HZ?tY!5RHpJsuk^>q05c!FXTU zh4RYL{`O{OyT~VPNtzD(wf=}JaTPu|5f|9fE&CFT+*!m0dEzc5{5dvK%Z&YD*%2b) z*Y0S3T}}O`LylkD*8S!XU0-**t>@=A29*iZ*A*_Q4JdAL%2zr+UTCti+taX075V(C z9&DGg{P!LGnM?$9*_8}@WqeQNDXsakPK|10wMI+MJ{+CQE7)804v?sy`$j*mvSKVG zMDyDtVRwA5Gg09;rzz%OPhwEzj@>ayju+D z$FdR=xKlJrlkyAH2ZU{U&>BYu+s9v&_9TXhqu+^dVWa8qMS>9OeVl}fb!xqJ2I}mr z&wWbk#B%VRp@`46aA&Mo*v@jlvETT~EB>)!oxP2*q87Ecl+Uu;r_V}xQYkHsnk>6# zzESc%vF%J`Po}(5&pg4YBtn;nzt7z99*2QTb3|D0?LDQ$j}k6xA(k3(FTAxF?~qZW zNrSE|+ac0q{%=G`;p@xS5|tY7Sg_LhZtzv2=Y^l;&vqQfWgEF~6;yR4nN)D;<8^1! zof%;)e+0Scy*H75dNvbGi9;?Mt8A)T!f6d(jJl2_I2{d?3U~f7Fvz)EZLAf~ZurVKo}{#|yG<$*(^6S? zd7wwsW#EGw$7tolEK0IRU&~|Pubp-Wh<+H*vC_yN@A$|ux82&Uo~~&sszgM~$I2q@ zqk}zqPO?`Sf_bwk!0W5J5A~yrwQV;^%$@dIIJfPM7Iz-w{vX!fyo^+IW<~?2P(5-;d)qWj7V}4WvX94-3WcOOzc%#IHUr?XJ1%T2G0tJ{#eZh+8yBjiVViRvjt6 z|Fd4k_b;cN$otqH!ZH}2rH4qzle!}sNfm!R$yzsksRXZwS4$^!@a@zEuIuDInUhrp zN%_A%g`7jFyymsC+&mcN<#RcnvrZLpGK^`7xb!4#xPYA}o-5AXS6L&CT5Z0kFo%*| zPRqZQ@LKCsrLVS?QNMNc3RhKWcJhwWr#y$Q&5VLU#*z)o-FynYk54GdlesLur>%Gr ze6n~&w>0D4yd&rG;DYf5FZQ7#O6EzA5K5*Ix>EQ3i-&ICiUx{Y{R5}D&C6>vS;X$v zdU@`%R9lVKU8nWxmv@#~EMlJ)|1yNzViK0MGw)54;oqZb6b_ znqK4g$4G>(h%31eVqPK$aU?a)E?@cis;an(>CyT~onDa_ya(i2521QxOisgir94+x zPDqr0YP>huX{=duZb!NmgMke9_^m7636|X{V^bRvatkPo;)^PuU{PBc*%O9W<6B9E z6!CQ2Y(b&?U6oj`IXyQwe)jSFG1R*jwLC00CiXEWY@Fl=erv@2JNK6&OrArC%cT1J z1H-waso@;!+KU4g&Sa@Kydu4>&|Ry@B4L(Mqej6zN#FF5BGey#Dz*-X~lU6?Ee_ceqwXkhMIC?m>R?AH0e&qH&-d%X4 zc0uS%2q{h1w~|=wx_>as8P#AXQnuRRr;BF0Zk74`npUw_B|aKfle$Oz6D=`5pA! z)f`yCh|lzc#SKw`ACd7t=XpEIB|YYrV-+Yb8E{MZtbC?+?AN((CSm7_+W8otCl;>sg0Yk8hx2j@HCnK>yH~A27|OHrvqTD-E$Eic z*1X|t&snI@2JKC6_@o#Jn)L{57PNK>{qpXJd^A2df;KLU5$%lY`U;?;a;~7rdkq^q z-K?*-7h2_DcXx51D2~Ip6e92f(WljTR1`8Wp3+-$QJO+p(E_*Gl^Cw+HCZS!m#@30J4kw}V%np4NcfE7loR>XWT9jz$`qer za4FqCa9NgdOS@C`RHHOoi)+yTukOP{9)U8WqF9>oFKVgSxR&!@9&P{fXtdJ#ZB%ME zC2d;knro@{`qulsb*}GO%HxR_FBi($S~L;dFfOC!TKhJ4qkl zM-w|prwKQ2ksG{LdAa4WTmH#VKZ5Cs7_>h%8vkv>(!PUHY$l#5wpch%NO7a`$gor| z$1uR(>l-g2-K8kYkGF2(9%z%YBw$`#{m^yNdKLX+(<5bDogNbcCu=%{2|$>wSGn%#+5K>yUP%)7&K}t}yQAE43YN-E89@ zm$qGp9Y5vg>6N?_@%ZJXmVBd%QzJ(978yM+{@H=T7tL4iT7XJb?9*0yt~$V?!AG$( z{GMd{?;u~NS?&9-)2Q`{&IKP~c-O)AMa4LdyDOc;f zKTVdM??o;UhmppzlU(%uK_?=zD3<@nIQ?E!ZF{uFtlvqF&faQ}o?v+V`@v7@xv^Q6 zE88^YyI|&Z?@Q%+3ijUzsL_lHv^MuR>lgL;K6LT$kdL6losHoQ1TF6Tvv?{JenGm1 z*XfFDmUu~D^?%<4!|k{Er7$bE?}H}pm=p1W#@-Vac%+yo*-*Q=W zQD19SOTTbAp`p?lbU$(2vN}r#V+$1`6Wsas-{oZ$UV9>ny&rc zfak=BVehAb-Bjb$pxEcTlAW1XpW7I}f35;eHQDBr231FxamAWsib`xiFG@c2rBLU? z$$=_k3lA|z8{?;2dml^CoukNS?4c3tYmp?wRmSd%1t=f+^f&x@-Ak8wM^uz&54Lj# z9!#_DF7Ve3i7CBP9I*RotM*#-GXo>nza?_mWmp z)CFr-Y3{>;v4Z1$L!MXuj~Zjxc!@0*pc=~bFkHy~PIpYYpv(E@3YPgDe+2%Rv^w)9 zXj4XXxW6Yshnx9YnnTO1-j`3pekiV6%<_J!=>e` z;yxyXue@mZR2#qU8dlL>L^UN%k$$+jubm;Kk3m^kaSQz^$|p{ZADnF)NhdQSZP-J| zy( z_vS|Zr+hy5ou%F<9AX>4s1kgiou_E}=(m&hQa@gJsS{?YL>pbO!g0qYw+H!V>ux*i zHzvO47J$VJ*ZzaW9Mk~H;upQngP<(2`>oUc(jqsym{2}0Czn-S+q*CqlsOvQoJ3## zR}`r30^^uUs7h6EfKZ_fp53~MDPyxOs&f6s{{w89{9hix{{#jd!+RcB=$Yq9vFj>N zFJxC)Sh<|fqL}_FmZow;k|mO=Z+mfx8DVc(JXM050Vq~g6*^?^13pci0BtUPpXH_n zt9c!)VWTW#}q2HvZfB$H!WlGn2b!Tv_fowPK z(#^CYt=2wZT4Oz|)Ft3ulOMt3F@loo2~*$KD?S;Urp-6NVu*(5mA}?Cz_n_8)XTYZ zoy>jxI(BrS_wcuHDQ_em6}{~0@bVw9s+Obi?IZ*Xk!S32EDcmk@M-mC$_N3Fo(PIm z4af_c2iX)7K$R-ylu_fg+w~SbqWs0!=`?V6vEzt1UpwpU*LT(e#I!W~Qbrx3Cfvka z{s+ab;Z&ONXt{kdNSYP*xAG+!Z+I3@mwoZvdcdUx>^2u@fd8%(gCJ32=c()AoJ7D~ z6yx1}cemc)nrJxC;i@RApI*tkM2Jd%{mBP|PX=lvU%}=?FXi>w zi|9GJ5Rqzv+l&V*D7OsQ2*W@>#C{0+eqQo{7VzlVWW%ZP>NHi6WBu3cs1a;jq7SnO zH3v+p;b=JcDL-x*7u?EOpaR63#$;yM>0W9)=o&k~_w;ju|J5ZV6^Z-ugBBhCrV^m? zl0Y8BJ%0C*kov&A+u#IDe3YVJ{&3I|Si~-6mDO@5x5`s)xWd0}=Eo=)2?vTZymyw^ z5V8%#ie12X&4GaM@!Htc(K_F1%Wr9h3w@az70QoHg@JzF9ku`duGKiqBr+cO)(?HV z02^}wk>b~WxNmwMk~`tYw|$egobLl&gK#>pGQYtYs^x0E{#pJOX#bTmCV^X`CA4U( zt_0moyWiJ9@f5b|UZ65sErnt<-QU~kwx~gzn{-EOv$@*ar_K0;OuPv{%o&%TE&K)% zGu5&1bxOvE7)~(3A+cJf6^^edDF_lUZ7=fF>saqRUg-IV&1=a_9>Q9@1S(-9EDLuJ z(Ro91%BLv>XHmxqlYVq`AOUcv&l3di-nyij!U|5UBp{daEXXH9$XLukk;z7d)yaCp3I&zli3oyru%2&S428HncTXHCy+^`o-sVKom{3 z<>Z-E)pjq}QmQFZ9sM$LCV1u6L#FRyWO z!={%p5m(Ni91YSc+Jvufe)7WA`ec2AM2BTLBe6KWVA?VercMzrQPKdZG!7MHA0;+Q zxUD2m>b|=6NH3^ka1XAlr2vD&S-+!1L8@=!Y*zfO4E7?p5#%B&}EfR`S zK~KWFQJeoiK7p46>99v7fhPbzSRv+&c=J%{j_08;3j)*UL9=}y6 zo$aUic&4wU7SI2&+2U|J0-l3Ur;MCG45nw%IT)4mm-ItejaxWM6xB=43yZX2&TUXk z7{6l!6@MIPhq$y@S;<2T3U2M@b(%&Sg5rL(_uTUHQWWAKriO>2aeDV+#i2j0x`)oV*%$INcr4HcKOqr;XJU6neQ z2QvGmhef`5%JqOtpSFv<>0^v8r9~5Lvk~>Pj+G4QK3-D&(Ly++4|ok1$o9>B*B|-P zvdH-5w-AQlmcB|8-`XgP2s`g-a_<-7u2ts{XpJZ+6%416T5}Jf>2m?uQiR1sk5^LP z^LPtzf_$dovudB=h!c2p(e6%8X+m$N`9t@0c6*Ptr0|($wLdF^^V=^N%ujB-G1M3> zwkbOQT0+CD?x_F`5nYn*pM~Ry_*rl$BGslWkC${i zW2I|bUQj4l7npK1k5|O&uqmj)$JwgA>D$LuuG@?Ae0m>EbeDhBr1Bnrg;Mu3i?m+v z$w`H1o{`~cjgP5(C;=OBDCyBc- z?MQxP9Cv~}oBd-<3t?vhPdw3dGQa7&7Q!ID)xm%OIR(~d`RDgEk=RH)QbulY+W=S? zDE4p^(iWGJ1!SklFA;!hR);dQ^pXbHIQcD&HvL6!sI{5VNVO6kDRo|0Ra|Hv+c z3rKwJBlv-6{d??yag7EK-H&iQ@2w^xOSY1iZDua-a(^8aSng3@z>17-AUGVF^1YQ_ z<9%ipmn@if!wi~x(~?KkHNN{sg^$L0Un{>zdir}vBWj~Qd8qth*Wh?X{ZW<6d>3Wz zroR)98l_OiooZK=liS_tnzu{j$4D61X1}_{S`U7TnD5)=Gy3d3!T(WTV>GGYuaVV; z^-aT%p+-YA*Pht&e9vAM`y=}W-!!ww@Dop9+;uy>UH?aQbk_>^3lbfzU{3k2dvJ>X zY$043*8PIJ6?n6l+a(PiH(k-J9~$)yEooUDbsV3S7C}ftC#*?>uL|Mz$Q9ubHybCP z0)-8dMaX;DaOvJET5gn1KvzmaN zfVrlV2%CUpN(uWdWs10*zztN`9SqSA23Mgc8twjiL6+GTBxPX z4MgeEx<~&KHA8;Ehsolw&=X7YN%CRHSEAH0P3CZyT;a>eWx-MzuJyZePg|LN`Ugf4 z7hN#k@g+Fyua91yrd{gKvM{1h>@n8?>m}jTlE;awuO(Fo@e=H4KHQ#KmqPDF=S9S8 z?XDo8*x(HTu@-v_f`J5`Y2E#+@p$ma^rFF$8j)3>VEvN~Hihx>#J&2TvD%q38V3#% z+**nusZ_^;3cYw?+-s2+uYLlBT%(gL0T$}e)1KBW__+g4DZ0b!jobvQaN`EH{Sq;Q z@7A{wPf5DHq>&#W+$*eOrF6oV_deSBIF^MmSFlT)U%2gkRB@v%d=GCBPn&6D_2|#f zn%DHhiEumKM-rmh#K*Dp|kb!?}2NdV3-NE}sYU4Fh^2O9oK+kC+eGHz}5T;DE{4U$4Ld zUEQLl(>BB3~%{hI_VETi(?%y&!Rh04vGKz z1n_sH%i(Y-ZccyVBziSRY}w5_Zlig&u$UAa%6&OF^xNWXKS@@%uH^XClcS?Y0W4Dm zd7=eIyRS0Wm?|_*j?~kq*Y4fD_|w>eBvVS|==5wqJ%pG_Ia17#-S@U=GC&SpsT#?g zDlcvg6zg%NPUMBL{u%J<&`D$qk~Q7_ZGFh5D=$$!RAQ{|^|)k3;#&9D)vV>`jcX&t(QT6@q6+g{e_QN`nj{`#zpfBwcc^;F6_H$cf0qF)^3-dAtdk0 z@Q_TF;_q&{*p%;eWTog}5yu!wx>)~A>BOs&9%@k=S?Y)K&*oaHTpyaBM_n-I>M0fU z9!fx93Fn#1>umpmCA^1dRW8#4EMalj{G$vs8lUm+h)XLPeph}mc)MB050CBH7l)q& zA5G@!hgpvwS}e3k_xoa4h0AR7zeW!YwZ#tNK^T-R}HYePJ{}nKh zd1kS}-;ndEZohs>nLNSh)dXU*2%$aMhXUu6(msJ6ovxQ4e6wk#mNs8DlFI;}m0X-$ZQ_YsLtTIvV!gHaz7)X76x;B!3t6Uut zljyHaq6jHFZadX-Av$xA>V-&>X&-}fE#0ywMHc~9z_v?Fz_195SPgd=DN`KiEKxaU z;ZIgxoG4}cUKt`uivQw!>~04R>pS&+Qt~E{$;1r}cPu0=dAikp9|b2^4#c%=ccL>@ ze-9OKh~oJhK*UtaXNOTe@cSNV{s9wIGFXR+?&L~(tapIjIKeba`2~vKq=ti>zk1*C zb4}M2G(LZCrmV#WJR|g+WxR-)15GvSRMW!B%~pwI1uyk}16<6XxKb59sBD1Q)=Pv*HO@3Hsff@?kKa!ReIg}-&?kLPW4 z{OC2Xi9TNW$-{G}a`SoULC<2S_|vuW;hL|Lwve}^r}tJK7v%LGAxHbh^MHe>gIAgh zbl#4wodH-lmY_*RlGvBqIzXlV_iB0!yxZBGW&57Mb-75Y=You3H?ad*r(d`(@oRxs zn*&e}!}c5yb&5cW)CAMjUYQ0d2)+~f>}}2{{CdIk@A#nUhJ45t z#_a$zk3fz_auO20{_Cq75(tjOH#L%O@)gx!X!}%Hd^r)o2F^!DznvW2+c?Lr6sZb! z3BwPF>0iX>!1))Cpw63s>w18j4ZBF}`5$guCGVPXY6)*ZHW$$aBe3!#Ad|WONqGYy zlI_$IMXwLe;OvvPg6__6&GI#c7vc3Ss$h(?g#6{Mr%-iXZ+EIYRBn68QnZ#RaNCGO zjQjzgQ9eNk{nKAzVtb2jHkYdU8PFuxP9Sc&3!-$M(31avu00-Pt1GvvOHIY6DvJPW z%)S=Zfx9vzE$d?S9DLhUilkqkEVIKE)(vCyg)(^mVRGAU@f1JOS)}&w6CMf2}Pdd{`45|xERL+=RiDc3TvSz-!hL8G6Dyh%_ z!Rlg#7;tBQOyoD!!#$TC3l2oPDb+RSRTiztJTGTzwuU=OT7dYyZ^EiBN}(Gl1P9@J zdDL0h60iRIC=jn}FmYxlAgN$=CwsTuAo+EX#jgKNAV4ZgfA(T8;>3moZ8cLUuw?oj zebh`5mTS)tbDCr|O|tprW#OMu?l5_4Qru2io!#MFlRsGd9>$Y@S$A*;+471B3CbF+UhE1zz3669|N8VWk z>4aY%;lDdKtYAAmsn9|x)6C-YGQ@b|*Z4uN+?XNEtX*j&C4Ew5X?W0sllYn*Tq+zEGE3EAI=lV?dOIs%mX$gEl9u%o8LJ6=1Y`8gyM#;Z2W=5Ug)pe7{3 zM!x^u9X5A6Qe6x(5S)HP6;p^a9he!q>4C zAHYHg(@~X0{vi&LoAQdT5#@JkMQU;$7iu)jt)HS~;(t%6(azib!;l;ScrOocLa}S)2Xc0{!j_0? zubBucaYYutgSWj*x}4OjD^0q$A);V?qUmjuprE2UETUZC@xFeI;0pE6XU*ZF7gb!a z*`7TG|In88mhO-B^v|2?5>H-;3Qpjg*bs!voE|8c`5jtSTqAbD-g-{J2ou1?=6R(* zz=#DrY24xcGirBxF{pjvffn}deq0-u$IAm;H(TW!J-)${xLL24J@N>iYZuoVYgnJe z&=NIb)~oTFn#s|sxh?yob+2@UNz=REibpK*hL~q!S}JZ7OkRYivvJHU8cT-K8Lbs# z9xcPI`u4F;!Uu~vlw|S9uFA^yS59>wmoEY|tmhJhjeoBCZ}^9N^LjH6qu4fsYWGJs zXJzG{;Pzf+d?HTRrd3p8|6AJydj{+BxiCD*#*zEV>yo{d>C(QpAXDQcUA*z?`pPp1 zn9gQrY--mMLopsTkcy-58C4b4VzzO=2a1X-Iu+lJX1beRF4uTfRjTnUmdNJDQ^IH( zcYhwOzn*0WAM8A|Zh7uAeuc2bvKpAuytemMRP{xB#j8q3H(MekZ?_!sX45|4uUw?~ z)^IYzujR}#t4Y=S{>?r9@M53>rUbZbT!OrDzss42Q+*D!C4DDCekd*(ir%a`F2vGM zNSEcDsm7zAqgHf^VOMs+a0H*7eM17qatCiiHGD+$_%G+Q3ijmOU&% zUt%6(UoL|XCgGXD;APb_d4qHKS`=1+IOotzlI7uPveOOU1A|H527D&^%kro>LM}t9 z*NSzwTYP7Qlz96OMXGFdbgQ^)kP2%SD|;Z$2Jk;Wo{)!AWZ!u!qsS{93hU5AQc2Qv zhLc&igWn1e&=Fj2H>dMJ^h$l7TFcDBW$&JWQ?+w zL?5rHhBaoTv!&M^t?-?g0?&n*GFS33f7O82abvFYl6h9I+0StNXV?ULu&_ewA_`}f z(meoYf*fBQJ>N@v|HnEznjBN;Py(m;&?!8rDfHf}@Q32hsTB8|J5IZW>VF>NP)JY6 zgqv0J!BV=u{%FT8+k)+rrJ~=2@9=aTVAZ`ZXC?@*sct$=_>1?!T{|w>bh~ts1?t#ssbgDqn+CH!&%4gcuh>o z72(NN&bAEr#s82CJBO01$u%EC;$8Sz;n3Dnw@f|K3-62C0*`He^3QK3xYGABx3zw(YN+Z0f9X9vkB-KqY zBbrYH#_){3ErU(>6-8GJ8JT|GFx%xug!>G5JKgDU-*f+FQU z3(>$$HUkj;m)rKOD7CFi{SGY8M|=Y6d95XJvH^q#O;B>Xh_|YqHE2mq5(X?0G7m@8 zc&@fw-}uCy%!D7HOGbg0EliI)oW&Z6x#RqX@jD^BE(Z5eNU{-z=hj>SM=}eKVYu=% z8KW=t&PxI!b$nX#bN})Ti@f5k zc=N{7LksT7aqIGC*_V67@j9kzy}_uzvJIKNYSzd7C|X>?F>6 zDvSMb#5C`Mzhs?!S0mQ7Soqw;C0m^vC=S}Vw^Cqvt>l95#4Oxn#&-WKbKP#CL?)Wk z8D+a4Xz}P-#{W>gF}-ePeQRAINHFsKbSZnwsF(xx#$^2KB%MdgmJCc`QR)47n#kKe z6S|w^Vgve+Of1HaISkq2B|FKq( zG8}vR2M66)Obcyyk6K)S2Wl01e3eM!!{C3umL|g=@hI)}xwXkvKI)Mz-E>rvm^TNO z21Rj}tL;ZzqOHa;WEh@R>biOf!anUTATFeT#=#K|MW9$+;(v{M4b7@5APVh<9x^tORAUgTd5#198A~6LFwU1c z>u^rHf?h#d<>4|Oj~0Ke591BOgXDms^k#VeXunf=xjxlcx*NkW&39Zzc_B~Ij-qN9 z$ZvrND)g6uZ4>u=c&Ab#q!gv>d5^5PJDD~MFE{XLX?_)Eq$NxVM79I2wl3gUIW>~_ zw_EAVRe^>>USB1biP@+>6NEG~Z)-+eBI+23HiIPkU32j|;NsOFABqzyZ^z$2Wg9I1 z{)D_f3PQmSn*3I8H5%{PaQPW|3`BWLfj!Z&L?HoKGDE+=TPncwMH_d+Yc-)igg{rO z+Ea+o(gS|86J9AZT^B{emt?%=<)sUj+HARCya9*{F4hoU1T$g>yqm;lOlCnJ+Sg*g zf7BYayC4a=EDxLd)5KN*^s;qN3LCT%@leCn6#EJDn*3T1ZsbomhTCF(aDGCSH4f!7 z&v=nk13YSApyx!q(rU0PP8GVt?jrC~BZ+D7ilil>HOQ*%hTJquCPYIs;v<>nSRPO~M#t_!eP6Cxf2h_fK z)rXVRAbur?hx0)qkU8B~ML)55jnfWg1oo1nFT9%7+W2>i{}}q@ zf=H$=q)FB@f|N~uASGjwd^dP^3Xn^=Mr``<9Aj(b z7@$qdtdIP62;ZV~KZCp^0!n>;c_hT(@ydh<5)Kx)}aCQrP~WoLoA7n zztC)+9P_n2nk!#GUf(xaKGUJfx`bTZ*ae8&yL9vE0VI^U4U=M@dvGoYyWU|-gE3lq z;Cjsab690=zm43Y$nOUR<5C1sNhpLp+KkBB>m1WveHuD`6KZ z{X|VMDO2;5{uOf=c@s-z<04|&LEdOX9=N{n`~xAodNCov%`5*Gdv6{N<^TQ-TgEQN zk|o>NvP1|OYj&f>)?z8LW{HelSz_!vp;SXz+GUAK*~Tt~5?MkFktLCx=svIde7@h` zb3ey(Jb&K*+{e*hHO)2G`?{|8b-vE?wG@4PZFAB4`)J3yTH8V%IB`6U0#UOpWl)#! z&&OAHRxN86p7-0GRB+%dvLdnRjw%y{=J_4_S;)-_PnlTQEpDY7(zH>iM}R+BIz034 zwQyojCJ5)-2fpk1Q&l$&e64}RL1>5=6iQ&*s2nJaBnw({u?nrp-NK1wg~_?D`Y@=* zJf_{Vujm7p=+ij>>zC$uCwxEo-++Cj?}JV>9b)$0oshXW3oVTe z)xf;+q+|ihd8;Tseq7NtorhSpvuMu&O6UTvDALOx^k(Iv8-1+tBLjE=#f{a|W)~4m ze`u_%iJ@gvN{`6HcC=SdiQ_hH>kiZO;KnenDOys85BFeppuF%i4bu<1xZuE`(#C}K z7_3eaJ?lUOTCb9$K#)&xBTrkb>6Ny>y;iCQMg}9ooDbgjS}98QCs7|HgVt|BCugx? zJ!8q?dX7m4W9(jd z4i-$GF?&yEVObWE=%Zf=EfKt9dQ*y$G}dk4M67-A&T4;v<=*O!t_)Z^0OVt4A_OWQ zf~^O=@=6RR|HQku(4sGvIybNrEgI9*Y`BlRKrOyzOv)ht46tBc7L1Shl`+MK=xf;D z!9Qf{gAP%dFhyJ-bh3vw689chN#<^rUAPBU_wJQws#>@~W*3{*&Sfty1XRu~*u`ut z#b6ckKGBp@XRv;VxZ7ZSa@Zo0VEPRvhDWb#&3F=#8o`;d#rt6LeNr%<8H)$<|-t5>EUc=M7`>X~>MJc(buuH6aP zbn%HalW9#3QOl8)hvq+#?T@t5>O{~p3HI^Y)h;$O6lh_4lN_;x9L1j3i#ZNhqXFf% z%e?P69BEUh3( zQU!9I<`+d_S#$?x+FWtn1Z;yLr6*g4#&D(^)8QWrKF`NT=2|xoS*}YLFytk23DhFC~l#yT~)uVLm7Xps4PDNghUe>gQVYa%upV^11O^85E=vtDm zW6NzkNg?`Hh)+3(EnqIb8uBUIflRM)~VkZ3vq1clOisb!PD3(anyQ8GD z(`Xi&KwI;4k;pbrvi(7Rp_(KL^#`pJ(dMkH(DZ^v3lB+wK%MH=NGB9}_q*^t4`6t1 zc~3?)NkOWDe-(Ykc4{T}yu3P>g zYpyOVqwyU2+&%-mafRlXyro`sei0 z72ZTI-!`S1kdcwv#CD0p{K;+x!E-T+N>xJWe7p>F=kAOT)AdxQpP|tqgz$cmki_=K zZtpeS?>8B_1wEAUk&CV1Z=blh$co3y;O^)p?zbordHbag{!%RCl7~hE>zDG9DW=b) zCL@?yx6>NzN@hTFO%vWYj}zx73EuK zj$ATmRHL!=OqA~Y`1|P{1Y4cnmW9$MRBc)?3Cge$)!4`Bq`!}5im=PJq;rct4@2t^ zc{R>bu7(mr6?HCnBBK%CiXq`!C*heee?aAa-Y~K$(qDBYE{+wuoYL!j!r#YvzFYre z{ChKziK2$VI9oDXZ*UIPuD#_*Imw|lRsAr7y-Iu$&j-rhsQAcE;S-GCgMGbYf=boD zJV|wGxUA%ir}BX4it_=CG$Js{Z%;R5`r55jS(t7u z8VS7PQ~wA-q4)C274j0D<2gjHk<;p)zQ2Fudw22j`pZmO^`=Tia8jN4nf^m62EMnZ1eaUaL&zUg7(F&fmb8Ggq& zE;)96HAai7WWu?iBKz@(S1NJCPb(kdTC^9gg6(wP;9$!j2Yn`YZT{#{zts!HyOl@0 zTRGw{QjHglf2az1Hg5z-g!mmkWr2@q|>lGY8QMj00HCG zfxA4aEeS^+K0CnNP1Hxh$3*F6nxC8`MsHC0T@gd`A|DfN zO6uRb{ue}hd&NJsjsi5kZ;tb!xdjoGsI)Ci$;6JJ@dg*hLE+DT(0H+`pW`_C?#M?C ze&{~D{}YCdd;bPZq+vI3+6Sz^=*A#EtsC7`z%gT*mL>NYeoKLsNTu_E-?Du54n1 z*%#%ImD&~W_#~?PmnI;>)%ZGJ7y1^`2WD;~J@;7Av%-<-V(%Z>aNevrDE6S$*g8VH z?XgJaC#~lsT$oC8&z_%eOInOV`!Q!lD4Ecbr;ps3KezIf#;`+}7UHywF_IOm>}8Ua z>I>)2`ajb39e_n#z{C|v`B}(epsSPcf4e&W=es)Z`FbHY<LitL1+;OL!rd3RvK*2 zd*8yr4}fD@3CtSxf!_MfIJ3Ji@?-XqtKowX_F1zwbW!m62MFeQ1g!W|PtJ4}ZUsJ_ zWZJISr+yyme&N;hXfQx*)(~hzJOeb*Cbxd~!NnQ?c#pt{9(WD5(7_M}2imZO9uQr$ zgXD1d@SBkSlS6DzrkSeKtkR0TF2F1Kf}KEW3W{)?#0@jxLg)WS10*~*U)Tg8Ad0k@ z6OqiI#{l>M+h2qEE@{q(8Qk!ZcKc1NPeGd(5oqG0SLS3JzCe1_`i5Hd>6Mrr>y(>v zM5T&hKZx*RqbD5p8fpilxJl6GgukNGw4H!4@Kx(Bpk3{+%iaOpn;j`Y0hfyK_Z@u4D&w^Y3Cc2vS&K*J!V zx$CxysomwTdtZK$DrtTLQ0WevM43dgY|W2-(;?h`^M5_zGyL~InE?D}`kvex_IKzV~F)KlzVoe^Rc_w zQ=8}4AfH83!UPeZK(zbbWOp!4bZ|0!Yy|DV)A-&J+p~r+yHqcBtH))CNSnQvXf`FF z82pRDVDkom1H|2NV(BtRp&tD+kV(Dm$Ue|{kwQoCro<U%zL=PLXULdp5c&^~-?=D~C`gO*CGFJvo}`HA3j%FbSaxoraxtr_VTLtO^^ z6EoqcwKJu5vd|sM08$dt`MC#}|5lUW2g1WnprQ0SIQg6)acnaw2j*ej&%Bz3VHn~8 z$y^<`$A{dhi49rv90?`u#GgojhV2{hwDn)6py6Yyx&u8%!G)m23T(qvoRc+-*2lIP z$<4r1nEg4rx??+?FtS}|`NHiCvQ60XrY50|Ym;*Q=#9c*PgB$VB@!pE4T*amIok5E zX_Y0wFzMhvADW-4!GA^HH2pE95d{_CUV2Q}TfEq!YO^tP`kJCi@Z_LeEirERMXnEg zZG>;MsAxZ%eju|bJtHV7>W82MA}iwJ8CVw7x%Khr!3GxjX!bz7aO!&&B4^RH=%9#C zBZqX0BI%mz3(Omf3SfK4s?PmGRDxRP!@28?20YFEZC0Tz;>zJL1;rw zFgRhg(@tFjh0VoNu0UV9lB<%zA>TB36Piuim-jyXtRX7#`>gtbQ9sb6eKS`i#g*BB|aZ1&PpeSKj*$Wf>{*_7fuLk7ik)SU&qt z!cW7Hb1g!mOPF7L;#~X_V{E-2Wx%38AC?YHgSE0NwW}>Vw&VtK((PGMoKO1*E<`sh zsQP9wgUIOlh%5!pc&2!v_{hakKf;UhiX=tZ_VYhJ6vpLI>qOG)_|uv0(~ZRFu!P;w z3Tw0&T1@4r31)kAUe$M!J|u`Rag0|u2O+&=aQ|wEyvkXWr~MFBK6OFm zf-v{VSRV~Zd-?1JN~1hP?)H=m7gFvv7)+!aW6V$*Y!Y=mOd_l}j;Q4y-_I8Ltf4ZJ zK%M%Hqi#$-_+nU$;8S;e<2kAhhI`P6Y^Od~S+drgRbu#8B}PL?M}ZKtj- zVoz$IDOA_;IEy02WEd?Ptw3MK5M7WN`?5@QA47)YYGX+Mkpn!5fj@*S*4)6vO>pj= z^|c5~_^CCGtx2H(D1a&4Y1Gj91;?$1>%+2b(?pkUK)o(+dFE(}7ajG`Pr-iPU}%J&|Wi#@z) z8@f!4N&Ao!uyJykS2+3y;)DzwcHNYr9MY6+%zz&B_sBmzX!t56x@GhqaHg2H<0BR+ zs$I>>fAPSIzK>(fX)=#728TN6B`N;f?z2GA=-|^4hUOFI7Isf+-k#z4Ry64GN_FtJ zX+)suN&Cs@!+))@wW;=tX-jlQCp8LhBBRnD@v?oCq7v6_4zgVT<86$|Vt8-Ngv{KA z*#MOt=26Yx18+IxPXq(}lj3**9v3d5UA?aEcr1X_o^bDaQWO|cgU>F#3l%KP^>2H|7)Nyde&qJQf2={ zT0*P(?!g(DkRWcwIs$r<{=0~`@vn{YW*9_Td}IiyPUGa~vjbq>WH6lt9^r6|cg7@c z6C2op5trdZS9e42m_iO@k}rrIVfxjj=rs)r--G9>FrbJ`2f_4=EB;*l2ASy*432V^ z6BPusLajhpKI!WT!`;&6ByLm05*&h*R%SLJ?EY@5Kae->!34zurY?`pJhiC`gJ!QC z@H}UqbwP~NJ@_6i0qC*J4OTor5>hV%>F9{a4UZ9IWD8`{rjldnn8=#dPiGG~Hq!&Z z=w|wiG{ybrKAq6ooaBh?G`PRfxFGuRNU8E4ckHKhMv(jkOKa~$Rfq4fdExn}5nPIC zv?N}Et@zB)^CwW4A1wYZ1a>bVm0P@=Eq!|B3zCEd^VN39R1*j8n(y|S*JWt?enlwC z(x;_wIE8?V?EFZzW$yR(g;w<-3fEo8B>g$)dS_;1JbvN_l*ZLxSD${l-nV9CiT-oXk`_=JK0RO-E&=1e>p*XM z)%;cRPuMM)!aN07xJMEE1*18_%47J{T-Pf$UeZC0=oZ-RohREfo&|ol@%jyvwG4RW zzRdwmBk)IkcYj?D@USel%D=Nd^U(-xuG5SvaD#Ge%MV)$1=4jOIU_N;F3`S4QZph? zCKGbE0UmaMpd6up9=iO~*=qtBM#9W==K02OIp!K=M#Q&0&8=`t6@_m8RMTp*4x*t6 ztA%@J>qDGS_(PTeWPA?wF_oR9eI!mB4rl9YkVW%>2$kK;G?LT>6|g}YgU?})Uda?} zmD6s=fzo#1H)0S!0-(>%*dW1F2n39zp9LW}LajDqJ%0AN$Gwa^pi9$#;%8$F&4c-4 ztsp-br{7<0*ayYNh|3Q66MKB-@dg+-2{)tz8+Wm&XG8hjdLRcrF>84KKc=Nh7STAK zf&&0rG6B{+vBetN*Sw0;ED!M>`l|d1c>{F@Og#MiVCkkuNSN4UNt#c>M`w#(OQ({Z zfjRp6;}Mv~jWGOppC7gB4)I+GX&K1^>k^MEnG)KlAsT#YnXwuDWFpqyk((N6C68a|;DAmvDi3&Q+bYun-Eq&T( znkl#VdoRC)pB3<-m$wAOOUjo!^)Ov+UMO@iGVvh2ntmK6>zsnoXZ){YS(5ouE@Q#K z8nvzuuT}CKFhhvRwn@-lFxsc&cGpi9z#M}7(|u_aVioFP_F^OCdSPqr`vIV~_wAZ< z&yAlGT5hy)_+=kM&ro0qxXUbKM#iK}%gT>lJ6?>QgI-QQz044};XMJ-cbkO zr1A#BoQI~Le?p8lp01MJ{#pDSWbyS|7>LGAK5ds&3i$Ddg-syKOAe6i8Q>sp18clj zHBj6no$4#`6OC|`tshZY=O~v&IoH8+54$S-Y%F2RIDOZ??KcQNZ;@$qqJI}2v?!6^ z4|6Z_96&GWV$8v@m8M5VS#(|zGzAiW@O=6r%aHp0fQIPw(HO~ir$n3J`|v=bzw6BM zidZ0X+wgCQ)|#NCV)@QzcR#REsLFw5X2t6iHx|JIXr0JH`78ZWtb(qmy)B7>#+-{@ zZKBv1cbU*3GzOUy($Ab%uo%Vq@mM49ZW{C$>@K?7#12N-Jgk!ueuB6j6Myo zh)oV%fjaE_sVf3A2_r_YIS$+N!sc<^TAJl=`T?ewp~|mkEL}94_A5{nZdjfVJN?`C zxyR`g8!Q))>4?GDCz5D!D4R}Av+GirJjr08k>Hj(0g`|zMKU-=DEHKc1Ogbb_hR)&|K7{A3R%iy_fW!35 zcCY!jCUFpr3ewOwGCLhuG}5N!&{@RZ_vn}Wb4BvVIRMX z*8Ew7P2>=D*|QuiuiY@*z|sgyKy1+wD7xoRV;`v~hI7zUqMA0xYthkAQMg*9FVLo< zIpSIJF^(C~h}f$dJEpzG-_?hC+EuQlk1F7^Pl4Da4vVtk1P=XD%yqOwlz{M+s0t(O zohX75wlYeRPgEk$YR_CS;(g5LQ4{-5&;xM#GFPt0-w%LsC;kHEBMUSIdJq1V4?Q2t z3fw983aB%fG;RUE#Fee=V%lM!Jw25eC1h^h6RC(s>tH=oS41lls6XMX8*efZ-8X>i zQUoukpTLH>h@v2)s<*LGj`}8xS1ThWY4VC=2%v`8-@x0pk*uhQUrEs$U${;)ErDqr zD=XBV#1JB>=OxkBJfbWQt6?K;#wtcBieS?!ygnb^OPh0G3b_1Z#69|qA_j&>`B5y< zv{k~lP#@?PaC@56u!(z;g=+X?m|GJogwPj)(<7XAAgEz>Q0(1O?fU3PMd|njY+@!{ zdFuyBs^dD8XnA|qhUnEsjt=}zY5nfw;f6i#S2-oZH>m70!G$_H_CFP+@Yo|P!w*di zp1+(rHA@h?)0KLtOB%yYz|pWA)swVe2RDx0D^+}tj^LIEO(}OK?kbCLqb2ZWY>y%* zZXL8$3kc2qVhxNjCe34!MpWDf&K$q97NdX>V3<~5_KEmL5ltr|;>LPM^xnJRZisjH za)xTGyO=r8CdN-7UXsVOvvC!_q0qI+7`+#z4>YoXci)8EJ>yp%>T!xEG9JV($$?Or zf0>l|B$?nRhTk1gX#3bn5->$I<&+!7pBfjYu#fh~(sOXiM`K9VT=9G8>ErjuYw3rQ z`pfN~W;oL2$Gc^l3kAqa6rl*3Q_M@K4}{<9_0uR3I@8~WQ2S=JapUuhj4VH$MtL}z z89Y0&Y3w)8wsTac9pA3`WCIAJ__-aw?4K%H8+P!lOj=hCGAi!LlE3Wyx3&b1Y2If8 zAFghGy>aJQHX#^zoJESYZmQ`P(e#WgWC#y9S;*b0%KeM^16Po#AXBLqf1}YXuSK$B zTk2E4NH@RJct|FG*|^EBQNm-Ir_0Di+-TQuXJln1h2jrW`>hm>E!O*(3Anl?JL*5c zmuc0siG=jTnbXY+FJjctY}9i4y6KG513|%FjC2?sod&Amc@=^Dwe;C&MX|bYrm$#Qw!0nmzkViZx$GhDk-Cg)`l;l>epmSD z-FYpuWrrgfkl(_t8eVhjpGP>N-O8!tPbFqjXkXy|3m(MA7%0`; z^UNG`%@BC?d~eDyn4*h-T?m9S{DD2)^zV(Ad5<1;76xp>BzuM;_NWg#Mf*JAp95$J z1ES(MTFl`A9NhdD1sS(lVd}bZ4A(LCFBf3RmZlR8}`2h(Hd-!I5kT2b;yw-{4k zDwzH~Kr}_JlBs05$wN5XxS>SDOW*{bG12zqZ$*G3xxxW9nZ@~@p#g`*a2=q~yN6|< za1PY@2ZsS}%ba2m|HPmSBR8|NAKE+P$YQ>~f|dWdP8-Q^fIIoldR!vf)i2*W(%3)b zl%R}wcMw5p2w39)6sVEktd+}H zn$Y>S2C`-(bfr6h^d=7NHQ&$ia^FX{GkzYe00^q^_9&kJ@0ETbF~%gZr)vzl)-F(j z`9hKAatp$c0Hq~^#nDtuG8j&Yzw-Th4JI!&WhGA-Sp_xHX8|H@ZGv{x845{YRtoN1g?O*VsqY^nuMAirGEbmwU4wVP=6IZtM$PP? zhy5amg^$5k1mvNe?u4z~)DEu2@dE|9?HSOckg^))K+wKA) z9WT&pg0d2N&TGRmV;Xh8e!fL!IRX0x-jX7zUBH{^mZPFe2WNniu&&vD(>xi3PxlL6 z$5e5LklGRkD*X#xs`Vid^cr*}zVRsZgUlBP+)8DUQ&3^v z1pLmJ6Z9p_5~eu2?=+Yo`57``d%COrX9}DmDIkX{>%~lo_F@*OP7wh!;-#V!w1bX0{Kn`&W~ZjrFI{{@JtnP zE=tNOY`Ief!cce-Q?M+MZ1(e!i{K8rl46b|AvbatV?r?&m{12gbc0ntus3IJ>D35( zC<4+0@9FR#IbK#UFgVMW3JF!nPVoY{Ugib*8PK{p2CB-u3I8z?B@!=v?zO)$EYtG4 zX?Vc0p%L>In%{lGb%5s(I`tf2Nh1S#;(NF(kD{39_)VtZ0X|hjGLBJ9QNyc^d-8Qj zb;XnNW4>X-e^sfKAUy^ED|0vZV0+;E&FJ`9mO)l}>Kzf-v^RIhh8qWC%o;Cenrdz8 z)36ZnARW3t)-L**N+FV-Et;8eYXQi7sl2K3`i!j`4Rk_oSJ8*j>I+?3zhLS!B#5GI zK7}xUS#l1PL}B#K0^{p3#-Q^wRI_7@4pa`&YYhP~)x9+Z+-41(M9yf=0!8#W{A9yW z9}&DthRc8C!VhdmPmJ4G=32CyDMGI<(S(0Fyxb#97}%94&wr!;}yDDv?3IZ6f|ti=aK}j z0G*Oefb6rmo!oRfXz{^LS+@KA1GXfCJ0Tb zgh~bMuq3``JbhxAaTb{n;#njhY;2!`?x|*5pW}170+q{Cx4k=njUDh#4k}yJj6Yfl zzruDxTSuRI7?XuDkw2q<7sCw1b4erzw4ChG74NLQqLX}A+e$au!81p_YZ^1;!Wn-` zza0IXbXgPxqGtxu{KANGjYC|LX4=-eX0Z|`@9BmjgZ1v9J!9YR{95h~2*YTwxi_i` z(w8nKN(BBES^K5o4}Hq$E*oFylOk_KxPo2AQ6G~8{IU>LZmtez(FBi)@b1->Z{`7` z85U8T;%x*!K@b&us@R89_T`iSEq(u;YIccud!D$)6j-5HjCOK6!;tP4R{=NRrLF@O zrQqpU+SH(oHbG+s!h$#@z6t4~;(E+Mno;0xh8C|9Og&s68czWjNM8vr(9@01Z&*0O z?>7xw)G5tnatUbYoa}Jlbx(~MKtJLa+z!y4dpv!xp^AhePl7N5$&Y(V4a-kwY3Uam z9Qj^HnvbzS$I6LbE`HCssCXL2A%j!2fdq-vkt!cO8uTYtHu1JCD+t3S?$upkqPa8i zUfqJ)*1&}MVEmX4^MrRBVKqkkKK61(%c39BGI#s8e6nY!T4-iqackl=x;&j5gsT3C z`%Dm>3r?kSH^(#kyl7Z@uZlc>#DzrBr|r0x3V}ngk(Q)Z!tCZEGi7wW&jnPMqzFZ< zpI5wyB><@dW0Dh(PDouwK}t zEF{g1o+TfF4BPTu`eVZhY&wkr09RfoN&+J>@OMTks!h;jHwi)19{WmwFY0A=iq4!x z|0Ldoj@*n6f}8D5^)1Dod5``QfxtCtU8dvcK3`4_H>LuLJHn2RoCCn#7FW90U9RM6v zC7i9sVXo@kczevzrJiGQ#z0mr0#d=1oPCef@&0V?4J;bbG`lHVsi|_U^yY+K;R$0a z!>fj~H+`m1t7GzdW!iA(g@H4|tuwS>+^D00DR3lGH|T_D@G^>^x>QknM2rW<+j7zc zr%EugEarj}BrSat5bY`jFwzhh9otl;T`H$Tr%;bh<6i1Y^NUEKU2>RuY7h_|bFP!u zRkw@_eF4qEqX+xW>WpKmBlV0k+5vqxHEzHsk2XwHm*B&b!Ny-{G|R_QFDk%PKs#;7 z|M4{@G5uO>)CP7UN?E1A%#C9%ep$jk>bbr?a~jKitisZbl!Um}_JjoPIK96ynPuFR zrr>v{maZrq&&H35(^g>Y3BjB}Yw9!%QWKBp3~}pNs=q%dD@Z{sDT^@IJ z9IxIiN8s*w@IJa?9~0DfzBJa>OKKM7@451>#L!1O>DXecgMD+BOt@C)C@FQQ6|x}Z zJLTutSPk1w1Xg#2c84rOEpiAgnM&hGpRs>>J#7rF$ub#d3HE^rx~J%!VyWv~-g52L zljIRv{GS`6xhPCN7g*t1?DASy3JppUow^KhSH}Q(r?&AYZP!_giGDd z8J#7($S?FYdCPE%3ngX6v$`}0Z|f&(R6eoFhA5VE<%V59RKd5StOmw>4rxWLH%0VI z>?K;O%BIFpe34I8?+d;rPbG|ep1VyNtpFa<`j+3?-y3hg$k;K7J=MeLz1SyXvo5g% z4)uZGZ!zxCcj9LX>}TU`*kr-Z;szUD(M$(o4w{+uzru((5gmfr91D*WDG3vz=o5Sv z`CPdd8x4p_5qd%y1bSu>8Ju;K60eyQ`y*r<4+fhiZ}-kaYKULvQ!=?F(lHe`IR_*Ykcb4 z8-YLG1Y5Ta!^ZNiR*PTYFYrpC&*?PQv*B(|@ZmXvK1SIPZS}W>>{-Arvb?I6L2a}v&+~46A z+6FLJ%%>~9hw}iGg|xmo_#E;8p7jtV{+{OyGFL;!S>eD4T&F(8q+;yjINmbx{+}*x z5=Cz1Sh2)t6k$DdEGBkg)PbFtF<59hRIB-E@T0W_G5BcW>c8--8^L$HWw)n%76rWT znO>K50X_Aspb-5;5+(g zYqRK7ROkhsTVAb;SFXM5DZ;USes_ao&m9;~%Kn|Y{H@>?b@gh>My(YBViJ76kK`jtkU+vf`uXh3_&u3?0=R68FA9k zBjXbo60J98z*OKgAfWBWE5%R>jH$zk)j7Sn{?9^3pk#D0NNOuEkzW>e<|q_GFTGWy zVqD%B^S{*)@I(r-Z$PrE_{C4fmoDDXxNPyv`3>3OVliYzjvYY0SL99Lf4|q@F09To z7$Hpyil3{!5`efY-kWM{xTgIO5G_M?_N2ytJ}hz^bEep;yT+eQ&LtGYanYa0Y2bKu zNYoo&s%OlfN*lRU1Qbm5rI8Zy8-pDQ3pF`{4{2N*g)_C?kgrV%n3Va?*Vc%oix@C& zf=P@R6l#bsvfSBm7vzhP^a`d#PmoE$w<2Hs3J9@|4gF;*g<=)20&9_0d)Q_t9c6 zL%KwtZ&Og_AW^8sT!g0o{1d$(7@>@nyLO9>;h-V8_Mao!hfpCSAIQ)7oJUQWu3d&4 zEre#^$iLTF+l&FOmLVgTj3hMnv!Gz23y_&7qUIrbs(FEnwE1uEr<0-5qwgrpibS9` zy#ZK!h)Daul&S#~eeCCU!Sd>#Re(wSrO_HGWF&@kP!D^L${=$~&_K3>!-+Uh(RB9h ztrn*Y;ee33|L)Ra_Dod3*~yseFMSIwg|9GpLuQExt;T=0*SG}-UdUeo)JdN~Zeb|Y zuOV!~JdMj@AWpg_V|^$A>2O4GZ_PN=OoPU9OutK zD`SARVw$s48)H-H#tt|nB^H^vfi2-5TBFjzc$R4@wkV}i81%Kn>P>=`xyMTc9Wz6t z-mUtARbPA>C6yAe52X)-##RnEawpk391p^Ar)*eL|8&fCV=46C0M7mjc)|ioEh)R2 zs{8LqyvvYD!}7^RL;O8L&AUa6&0Ae zqsh&UNp1Pam5ktJ7`qV$jRqGWF(IB(iYysU$lxPVo(e?1)%Q1mE?{Ew^3sVzI*{JW zK1pMbWTuLtJx~ZgRCN=qf9Rh-^3j676h2DfN(MHMax3Vj zk^p9|8XIA~8(}llraB`M(OUwd*`wel|5Rx{o$tWEVtDw+2GRG>^L}$%r4(3*!ogAY z-G2_s5(zS@b1XI;jCVQ(#!nu(mwab41gTp+U?VDv4fVFlOQ^#doK;rSt|SDPmi zN%95@$-IQTF@szfPiI47-jSEDz*WsdtMh>Nnw}3k=KcKN+ae+2$@zICrW&Xd4B-vx zMknA$W}Ug?CrmlVfq(giP${b)mqO^=TT{lN8*83PDprSBc5DUrpA(cqPH^IDB3$~f za%<@}UM8_9-AszZ*65jBz&Qh`-1?^UYaN~0q>>}>-!50E{P}OZRKHpT4;_1xkRBn+ zV4q3TI01L2gR3b>6btt^C@{g^fp;xMmfZd_eDu<$_#`|!Ybc_M^*?ruj zV%IY=F_fR#@uSWP&B(VXGW@dx{lC72Uue-Gp2P42!IcXhOeb(DhX)fQM!(1%f|C)I z(4l%EK34VSn(W!bsy+Zg-tX1F2%E4-D8#zLPs-Phs2|XC*sDs%UZz6nl<5KwMGvUW zweke+xe<+~&LMx;X;YHcV$&>_A((#&u0kl7`kNdUN+GA zY)Y9B=-?*IS){~7-} z?RSU_5QVd_J`)4@wNE*RS`w5$B@~KL{s&oW!3&0blBnf1zxDjfJ^#F(Yo&OtoAp=n zUT9w33guOHtE?fexNJI(=w6l`xv2I1<7uZaLc4bR&UlG=Rlm$O5%On#a;V@py)|}@ zN#L#Jk%jBvp}!R>bMZ!LW?gPxnCkf{&zoPP8fuH$KO5K$d@1^n=`vL1eofeq)_QyO z7WXIDm@%X%UXs}jz?QU)h?ffPvFN=HU*JufK-Ez5j@BDQ5HTCnC z-sGuQ55K;A>~QHw@lNRBtIx7zEPK4}z!RUXdy{c)TF8_~`s?GY<;{(OyWY^D`76w^ zM}8u0z?w0YFcQ*TlQbu3a_hzA^RFwvJ#L&#T4fSHdHuM9h{II#&1Y8ze;d#owUm?{ z*$(0;I`llWG@Db}#%$qNfMxjhs*3xU!dydJv=4T7$452nXJuBYg^AkQf~0~|Yu3pl zxz|j!w)DfXKMS4TW}i@7Unn9pWn%|lhTN%q+;7u!0e7IDxNWfSg`4C~?9TbaQ9~Sy8P~^6AP+4*HfeCcT+3;b*dsfo))#E)z92JFU$-)iQ+tHXLYz_w1_-6 zS)OT8e9GK?@X@_QIX&XazTT6!UR-?q-TOxlW;;0OU!ow^-Ge*zudqoTiKkh9nxG(}Mazf1iHJZs7dIa(lJn#I}k@rS8Yq{2~(6ulJ;9Ct1EJc2T_NGIB-^N&9=+v3lTUbiWnGuBen+w*cm zP7#N<*YoZ_aX6P^qkMfEEcKe`(v*Ha3;#-cV$L_2h$!oQ@4vg}$EWHZy)x6=|F!D& zOYeiKC%-Z#m7zi@2lLxtFVZk$a_UHCQ^PC|x4`UMd_GHNi zO)MO#IRYC;=l0DEbE1`VKKA&3x9YKza9Ua|Gap-ttj(qZ>ZkXoZ63dN54ZnGQ0cX3 zIq@t4FOz!_iY)=_<2Y9@Y}4a%RKnXuN*X_Q-Rz<_b{V((y){Hlu2e4Y21<_jG@|#0gm2lT&L1h5Hh;bHUF2zXWsh|0 zW}0l=3*2l!1aYQ@Igcc+N}GV@{_!IxT)$e)@3WdAgv*|qH3oGfS0`V!Gc>m+o4P)3 zzoI;N;8I)eq0@|s9CvJre(x*9E(x7|ahz%wsz|bTqvsuzry+8om=wE59d9!cJYD|#& z_1oBNKBav>XXY8Mt@IT+kc-@_FIZrge```+>?=79!zQJjb)(ui4`*7__l@IMT){Gs zN{2wbg0*}WHo;QMLY`d6t#>g?)3K-s-y?PYY0dFX=f}VBlXNFee6DaxOXMJ0H|S9u zB)+=4+`k)C0(8B?DghNi7EQ~|$Cw;GDhMJd=}?vPj9oq&>_ z{J?`)B^6s3VQY>*N(pbebQ(;*GA}ZgjSrP2nRg z7M{4zOIc5Dxp)}owgxya=IA`zeY<*SyP|v5Op~!dGAwTG14!iQ$ci-?<}~JgL#Lf5 zCo>$DZ4U)KV*ajp>T%yA>C+capN15@^W5%zo4-XQu$Di4zBugL8>!uZ^_g;MEMBrG z_E*?@FxEa&+v7w2E3MqO!8SKU2MX>b9{pV6)hK8?52z!S-27nw%*$4I-N#^ zdEE;rsADKY5X-QX)!60VZ7D~@Tp&h8*=Eu%HoihwPTWL{fxUH|m!CiU_v%k&DZBeQ zHAJO^#0;;wmorIhDeKQQf)}+*eqXWeDLPR&Z6Q)ftp5Bj7r^G1!a{s%Q zodMMaz|@Yhr^k43m}q4B@ZRxOuykq~_hG@5l_8W}&TlZ?cmQ{cPu z-rW7v({dBfp5cTog4NXGvy@AB{4;)CHVvW!}|4KY2^T!1vL|su#Up7M9Sro)|Gp%@VeOdZln;;C1*D zjjDC+MLUU_^#pN8I3)6CoX}r z4|y&to)Fx-|Mr@9Rt_-Sx3d>{e%?Ih+OEF&cBVS*zTB<*1qJYW-x^rqJ$9jrQ=-eC zm8gYbay@9~(^GCqm(s3{eH3v^*`(vN6nEg8zPFb5&QbQeyGN=Wa|)aTCdM|SPpFu^ z0&MnNXXROLpmdQce3iQoUM%X*kx=zV=$YXM6_qnXHiEy`t^T~lu6<7P%q)?8{Rpdy z_ox>OmWtH4vVv<#m{8AqCefY$!)VKSi01^K-@*8HPpz^y_pmcZ$a-(4(Bti!7ffzy zIlNBH9DM7_&~;a4D>Q?7PUW0o*kOyuy^q;H{X!TfrAL1Is;+G?kT2TY8UfkZh-}2N z1f>v1PFd~C;o9F1($2d0ApmhqH4;_MVch?$P@|qyr2M^Ra3UD98|QsB{lfEGbJDVd zl~%q}Z=`4xxq5JKo;`glW{gF?+Z5!PxC#&ss^qag;iDwn=sa#)}AUM7JdiL zU`D2gcm07u-wT=MCkd-lmL zKj`jP+>ag8df0{;th$ofdvPnTs59NfkcH=wtLB*>cH*lClozdto?2DL9h0(8rjKv! znS+A1`k{H6p)B9gJNaRu2Te+BYwt_Bl`wX_%97lEq3ZpZA~lgL7Ew01mti1M`yrMl z(CR=}21++u`(Y?6%p^^T!}USuJ+hz5BBk-H%2Np>^*cdSqNAhbNpEsSfim)LnCRqp zpd+WWGvv)@t9RZeOZz@s{>GvjxRMik_(G!AQzn|#)=r)S5D>d&MLf=@a$W7N*)46Qt-yw0v&O;N=M-fJBDqkryhVXirQ$(tR^3Xa8 zV2JKmN4F#dc%_T3BV@Ss`CxJQdnT%UGAG^#4pw{&bE+;5_QQ)&{I8IWvBtIS0Q@}+ zNs%ca43d2u8x6+OaU?yYP$f7jLwll|K05K_*r2C1f;N+p-x#04p=iV1Ht7u?+TH(d z&6sonrQ0G!dWJZs!7F}+yJQTpHv_ILJE%~&Ke^C^psT{RUm!P|V5ai8j2?UmxjDME z1=04jBz$43pk{~BnKzLDRC!BaQAi+qbl~D_aq>hrfh1EHSRzU3!3akJhQ=rF471xI zoRohxCGI1ypPRNVJZb$OhWQ&sZ(w|T6ePPNFIGh^Xz@nhHVlUdA7dCaAXb5&N^Iv& zTUvF4&NSL~8-&wGK>{6+vh-t-QJ-OJ7ldX=Z_sYFJ4wWpbo)ns`}*uQ>|&waS+cia zxR9{fA1^8frj3iOQo;X7O@2;y3kg$>Gvt8E6}pecD3$<(5%x-dE?C+GA?Ey}5D6I$ zynCSJ>`N2GCJz{*%f|ij?6x3y|GWjRS8IEsP!pd-bK& zm4UoCRL!Q~&$5|_Glf_N$xDG3*x?H9x29)_Lg~gtsF+X|@cpP+PL;8##D#;S*uS#! z3thxV>U{{oiSyt&tE(&zmVrxh1ISp%5cW8^Ie{G_10!*Wx0_lkUId6^2<<2xXm1^0 zL?8~5b}&?1!{_z^m#jUWO|(+;+!2r7hh^;y{C#kbT>K1b`D>KawC#wX<$QhkHLb8K zJ1VX6!Hf1l(oayEg^;pd+m*!*M2!OSqd%H*b3@RgCK`sk=k5S4z6&s#_R;8a7QtYT zaIw5F&1c-EC_fCSSO^i}dp*dvEnfM$NS%FAOgwVc4lzB<5fx#CpM?Zi&S{JXi?S0G zU_zr4!{0LqK>)k&-m|eq#t%a#B{M(P?1(zzdUVuCcr0=6(8a-L_dsT3)J7p8bTNva zCOB#p@i#LG1A5(L6YCKuN(=OZi{WcT<-4uLdAdLSdVLepzVDd7GW(AWexLS%pp z$ZoQMB=Cqj58<&bcOOyyB4lz{1cM(NcWX3H1IFUpy*l-vd+1vm8`G61jmPS-m?N}i zH2)w>dW<1!0MYzdz^G^EBf`V7L~A4Nhe(7daL^D*+W10cQ-*=sz~j?T^4LspvOw#* zgP2aL+0hIAY9PXGdJtHK*I;n<7G?;_Vf3Mx{qhUY>}Bdn0$2KCFT|z6fMSYGJzff`Nx>=&FVOuPsbq#OT>x;Kx9vVH%D%aWyN1|hq#43(|WAhL|DWo#`< zL=>{`JCS9Oy;2%P(~Uy55FxTv_I=;8@B8jK=Kg%Yzt`vY&-3r|dj7kY!F8Q;UgveZ zm*aiNa+qW7;cNYvY!bpPcYuHJ{a~~I+$|mo9&MQ6@g65kag1P!`DX}21WkYxD@yA+ z&#mB{-Xu=2z#QRuAjL|P?kIcL;QN%tVIRCmKfz&&nIq*9>fdKU$s;K6S5n^J zq!z_r!k?fPF9>>AyloAwFl=v?Cavsb|O z_7Y8jXv+mfX<|iRf~A37M+zI9Ll51OI$r+Mffxq@yEO!n+#vO}{5@vNst(OHI?>Qy zWq4F-!zfqJn^A@GkNbCR{$)dR-Ak$?8Rb^vyp>Moqnlql za7V@8`}5>lOi_6H2%DU9=7onjSLT1a&6S@&Q1=#c7*}2N7h=C_7kkxi@b!sCc6a$N zpX0y$j=U)+@G#|jth}4+nIOGd+$%Z1p%?`fFLezK580&CrR(vGb-4PltUqN8x!EvI zaO^A1vNF8E??TcZWsgmWN>l)+C-KbL8pr^F91{lzDL4iE~`kAVZ?iRaHZjw*HRf8j_p_wVJsOf z*l$3apurg*#0mSB(GFAkZSq{o6pE0_ZS9p+sl!dPOpB9K#?fsdJH`vV(>YutAxs!Nqp59rl*$_7E&iVrjPIhTBUxo=Vfo&o zT#?nW<6}a&Q%3sDAUixQI1$5%$0;=ac_;TXA2J=Uj58UD;Dc(Mn6Fo+QhvO(a}=H` z%Zz7yDPLFNyO{6MKKX09($uW?^~XHpNirGFZLe>RD`EWCV{X(ov!O$_I9~(6st>dI zBVg5(>FE<9$hk4)r8OV3!0x8;rLc^s(!AML=w)!W#pZpcQ3=cn_#<-Z^^L6=Q6KkH zPi`iB94CA;3;uq4RfLm~u-VXi@w0w!u4PEm9Erl5YA1f)zS?eAS7vKI zyy(;R$;fI35T2xEJNEaampb1awceFM*P>qXrbJKKq@~k-VsxLSb^Y&D^uhqE8zVNE zSI4!_LAUtKOK{Dn!9qOqC=2}1HpV~^AJL5%ER@Uu!Qyb-uB`s zp;L5JiG5_%H0Wg}nLRv(upR`zy39NT>k92z%g1;D-Ip53pS*$>ky?Dem9Wq2sUoi^ zhrCPyrSO|98`~Au_yn4#*EKk+zWo#PTkU0-23`!GCKxnL=E5zBrTdB`CdbhJ<#g|& z2;t1V1M$F@r6V^W54ddbmU4}zZzcjwjeMv+UPGgEvQ2(lMV7_X1WhKzY#>3+MNFxz zT@q=V$0K6;n9t$P-AXrx1XcTt=aN|N0AAgRz#_Zbnfk3XSlwk**6}LITZEju-nVF5 ztKVdbq%}SL#mnQcdi2e;H0xcK(YM;&^@?tKVLyCh!o@#oS7mAwTe`~&UYCh)s;b<` zyr26Uw>?bLl`?08i@tE={%x>5qsXvrgrRh-Ppxn_zti<${GN4ZSf{JQLEq+ny2G7R zv(O_vW76AKQJIhS7Alf&8i;pFAC|X@mcI9Q8nc^g8@gX$IeKW6zpCmwb0)s8?7*uu zcP`1GIs2*WN0*uY8k%E)igS@!3tPj!MLH_#GX233>`y*!ZeRAH-rL=nHNQS&__Qw3 zneeqVKPypphsj`fNEiL$O@C@Pvq6oXmSe_aG5M>%J{jI>e4>?TV{q6rHtFmU_s8&| z_*rS=;;aMPtndcs#lc^{Ikam7#>~F>Z07wDv$iN*D1TSBIqJ>wjW<1oG~|< z+M0NVwyFFc$)LS~hT>)OKQex&;CoZU{*y{#il zz1p55+}^OCf<4y1R$(La`lwvd?zgVuola?|)rN(6u(k5BOH2sMtuL!&nOAU0JJxHw zyKG+GX?3d`&*6b&!ll=)(ZpjyLlPE|$B}pq9A7)& zO)^M|Y8sv6xO9gRnVrZEJcl+%WctXNst#^2aX=>(MI|KP3>dt_^rFv%%F;ZQ3m6OR zW90D3MvFvP(5KBa`Kv#r4uQgQNR>&$^rW;{-u7F+e%7z2IszM#uMX2Npvb?AKO5%b z*WmRW4PvndU$bi!4H`3An+zG-rB?QNP_4~Q2I%ty1g()aE_rh?o|%HZm5swLm5cfX z$p3PBDx~d?3T33-(Y*T{J%{(f`r*5(pPt`OPh3eo*>4>6%&e(Fj8%qL_oqUDd@6;B z3R7J;t*I)7s3<}*&Bzy1v+B`H^Dy`6Z zh-viIKl# ze>Id}@NVntWuwx;No@L8JCAqo_Ks6=MWeOF%2J11H`m>BN)2`9t`!~5e=u^}+;zLj zjB|Sw(*Z)1(SdO`qf3wt13E|`eNR=9qL2>}rH9y0Vx&!73nxUgjq7FNBct=MvRV)) zlL?&*P0YFln2b&OGO_BWC6}uG(9NPKWNc12;WGFkQ7dpC~pcLb{bo1&r1)g%rJk@7UX`N2`3}X z(igYG^O;>Hr&S^9N1WbyPDbb^pVowjLsc^QeCWu6_NH_*Ejq3l++4;bnW9}8k1?X=pV~WL$98bk zr_NrgW-cqQxBb5RPtF*{=<{Q<8}?#W_8S#xn|?M2f6kcg9j)f~-*>S`hS&UI=^Y6< z(Vvc1>+hh{%kmETxr#z7l|)5xKAO5ES9f4Bds1(q>D^LZtq7aJpF)xO8gsPlaW$I; z2Q(MtYmto{*6o&9R7QscLx7Q$K7vg2+{JklmY_G{PXY$7rhf@%bt%m%EmPQXVZY>2 zpyM{)UXuQKWQmx7Y^CX59rxJWF_&>5`>FJSN@h` z-zpBC7IvyE$bNG9;GM!rr~AbL?zQIG&5iC|G?!AXhrbw_lP9qmcN=*wxLQ8p7Tf7z z>->(gyFtO0ir<{CpUT`aEK6*W{jysz;~QiN8!siz3pU=N3eQ{IaVa#Whj=%pKWhA} zHSuaI#^Kpi4DRQ@BYGfm`&tv?}6hdfi zZ|=MCD-*;^qC8)`5g(w5jU1FMk>r_G!v}mn`s=mmkvlOc;JQ&0u}~$;a=-n~&DH8( z5@?Qb+76%l%wFs+1kF%!r4Hs|cts(+O%b)T&nPKk+Is@NO)n8fM4tR&ip3`NiA z;u!yQ!mYmJjyc)(s|hoCgXJDg(-+DnEN%NOy~GG#P#P2?L*;P#o5@ld5w^Q~n3?2U;B`{8SN>vz6>BI|a`(I^88%Yi>8S%E8h9eN4dI=ShiA%Ou}K(_{qm-- zvC_M;C2)11f!n$<0r3@ywNrtYfk_OB+H zLqg72b#kiud9cqBzc%{qNU?~V?Qwe_p89?J_et0F74@l5#)pXEJe~7sp7gV4F^#!M|7Jja?7nu zFg`dpm;COiK;=BgQ?V;+%1IdEvb!_i>nl_Lo*bNuynK|~OA{;duBf=^^;e6YHdXej zl}w!}Wi+mDsI~UY&J>!o&MqQG-z8*p^PW`5|JfX{YR-v?+uuyoT3JiAA4qp7G&wi{ zYUnG2C1d|`zjp~sb{)jDg^u58C4RckSB;x}`&Ivp^)4N9g~nC2w3nfH$c}iUZ>0V3 z(+W8{q#I@Ncc+Q2RUf3Y)W=$B^rv=dh{#d%ugTnc8OR`jD?eOi|Ln0_;X?a;<>`EZ z`M|JG! z@{6c_z({I|l+EPCy-1oFGor!Yx_teWM>}DRs5{u~U4^c0nu6V_kYsMvBbM74e`^J7 zB3an9)7s2`lod{YRSk~jD{0R)ph9ETtQeEuSs2yXKU^vDYJ^+z z!k$r^iTYbZ?Q)Hf6Fby9VWa>3OnB`tzasj9U;el+HEv~EmRYTqPnoq9|LXc? zXzq5_QvQ2y81sFER&v_wFJ`Zrv#p6SK~{%s$*8{%)#pBNUJ`6{wlu5UrC-u$A8x3* zL>Rx~-9mkrxG>qh|8&l|*|2bXqXtJ9=-XjtIsRV$aCuyM*}ranCSL9S$8)a+&c_P# zqg}#_jOaWqa*$HXL5JU!UR(~{S&~`#!XWslqfEu^o!r{+t_oX*)s{kEKh5iX_wur7 zPOCc=Z@-r1XASlop4cjs7<$x|W2y3CWI(-W+G(cyx^*q_tjUzyR+d};#ERa7gHtX7 zQn)WO^OL3e?jmb0bO~~TzJz^+nu^2K&tGjlorb~@8Xvs&-|}V)3R`G3`^sl|)h&8d z)u=z%zkWKQ!)CU2f|5Mc{bWkD=A?qx<8RbN0e@=;+BSyl7YtA^hbk6c31JY%rOIS{ zzxio^_wVX~ql#{teamUt?H_m-w-+i7k~4T}&vGZr!+X7sk3W>_M2VSB&>8G;!l>~dfzpIS zefQuYb*2zcF5O!1aC6_DyWg=AcePluboh#G7d-p(Jo3YXDJoVJ2=XZhK1YMHpIsfL zw#v?Ze|!OXbBrI!)EPpJSL7nyLje(<#kbda43bn-0rHl$%~d>H8d_8he8#wM{PV2s zt#R6TX=%^4^eeno1|QtLUPkHi_X%#qbC8|h@T~WlYRoKp(r4s#bT_@`(x+<8D;>^J z@T;+8y<mtwFFj4@jQ>rQ!tyPX2$r0EHMz znxNDstRh9&OVUO1`SbG|QXc3%}a!UFoB*GsB|I*AG8iyK`X1j0`uABHvx&9Qgo|&|s zW`8^_xKPpI|M@$y;iY$-USnFh5XX@kLZq*ffmSs!^cgjiFzGSZBrweswi;V1!}{E0 zR(^t4EJV~Zr}6Z8lIE2bt~JT2mo^+6rX~C4m2LK&|HsNjXOnRoS`Mt^4P8S$j?Avc zR5y#mu3$ii4ztfn9MHhU9I^dUM%D|Fsm?HYuhg(?b{W8AtYEtFCz4JCR(b5+k`EG9G}F8jAA83VfaYf?f4 zHIbmGBTET1Yc7ne?u10RGh5#s?m*|auCYI3_`f5X=fAHt<%78JEpFA>8U}rSSjcZy z*o3)JJ^oZ&ZxD|rc<4pkv;(tU$&8@aX0N$T(V%PoRFYSdOejn{F6Q)EZ|AtZ2Y7+y^4*M1RAWA+f!Y z=cftz*Vq4nc!fgm9mmi2^8d+?!dC~e7+EBHuZBYC_8-ulY(YY;h+vhd3O$2sfoyHt zev=T+RYObM4D^WDO%a@MLoU!+`IU;x{l{ce;+*bt4GGgshz#pFYfdnMl!AOv)cTXW z7A>JauQWei+N6r&@+8>wt?xc?2%01B2X0_0RUK53JW%5rfVMX$vF)&zeOjDLbuV+gMlyS*M#s{4M7&{NXD*^1Lc#KFJr0KRBz>- zfG&!Em(_uTpf$KS%XV!2e%sQ&ak2oAx%eeK=z z?#U`M(tPke8z6i2m?+>d3dpNP_;a8D=05x2UX4`m2}LJa0M}ax6yZIPrffi6z-2xF zVPdoM?n9X{&tWAH2l1*~~CDXO@_FWnE-2|&_hsdo+xy5OMHK7gW-&V>* zN*>-Ex_V%r>IdVITXg_K^ zZa<#B=0;ct9kdS`Smr@ra&uBXMS#rYoyAC}X8CTJ5~a~~XgT{PR0m-cqZ9)$hc^qf zN!te;PXlR`sgczKq#4wgFK^ScivMUNpJMF?J$ansOw%?@(X)o{E3pg?C)~cr)Z2|8 znsu)MIsK+ddM#ogTZ&ZM=Q>_5%j1erQc%4hgu_b*nyLn70M^GXZvg2^^9N!Cdk;(& zF^Aaqv)x~oy546CH1m&?Ax2THhTQOUm(( zTZZNgx~Fjn>+RrEEXDJ*%9A%~^aq#maG_;;ZRqWf9s-L>XPcdb96q44(V=nGxwo|5 zs=(pu)OGE3n4%g2Kfz>=I)qJEBE#WW&+EEeNj24IKH2cLDWhg*4ZVnoz-$o?Lm1IS zN6PjR`)vdohH<%~$MI)tD=TYfAR)j$-Bfc9g?grkOC)PLhh8A~^5N|FxjmT4n5|5w}Yg3i}B*ST@%_ zOEkN@XMvPuTlGe`_H$mO0L2;QXTjoCPdFTsyS~Eow^BXt>Jta;h-3SN!yI;IYR#)$ z%5>=D^%4-#dVvipAM_3wqcA~+c1H86a9p?JM!gcE+7PE!sTK4Lm|l5d`n9t;zMj(c z2)05tH%hS~sxM_^py&EN{c-HiC)Mh{V}TYNPKuK1_sJ<9LVube8;XFyG{m#LgvfbW z|Iqr6xKQfX_xEALkC~^*sq%EfKk1f{u|1^7d!fjm_N?6uw`PfX3uox>cQU9zMjyvb zAGy8xd_(?#g3R|A-qmN9xFDwQ<$QaDJ-FJ0e+YUqFrf#u3Iy*skcQH`g+=P!_>dIf zd=9Ai70f>=bm%kcJ>iUtyt4t%6rUkeehj&Xx4_hAOV9(N7*A4j97eRDWD$J`RmoI- z=*JpZ?5<2;)E&l@g`5cpQp1ul^ zlqsJ(6GKbEIim$4aeg~auCdn{^vfFS^9Q_bIi2C?w&Ok}oMZf$%Uv|c*)6N(tr!$~ zUbvVSW{_1vt3V+mx-v-d$LhBcCwu7XFY)adL-RN=T$LLXCd&|n>SXpz%Qhi5XZ2v)d}pP`H#7RiSjAY1q4)^W{c zkDDdzm5pC7I7W5xj)SsI9VN?yk@-4X+`2;~5=|6GTE4e^UMtm+HiBi$SRSkHmPK^( zK2>wvXNu z)#~~P_m1~12(Atf?H`=ZBgn^CGLRj7gH*Jz0tJ*cG?r#D`Z1AejiFj)Yh!^Wh}0X@ zV#Rq+8vLDw<|AtEwEe(W=el`U%KD&QlgurkM{_Q9nP7$Q(Eib| zWWTS`sKhkVF>XEnY;1bf4RvW^U;QZ)#woVxm-KjyVDQ=tjMBPxzP}GO!=z%xi$9^I zENJ3YojCveh-Fj|G_a5JOm*9*Yv!T?nEJf^rF`tia32(&J4>hP5L->yQ=c*Az=OP#bts6>)R3bvS z^s>~i)R6&v4o@+n4tsx_ z6Zy&bKFe5qFrK%~U*m%UZ6%W>UHZ|=7{w(1@g^;@aG%2H6t|oqgECLjl(ww$twv4R z>@e$V(omL|G2S^sMVnHLxzO6T;RT~t1NClTkMSz0CD25U*^8qTK%uX+_x>f`rV1U>WU{Y{k-bC4Oy1tq)>G*RLzd zBB66uwg>9Ed0+Kerbx+ghMO=|G$Y)I=MkiCOjsLp4X6h`pH}y!n@VD%uD)-v=5nF( zZu=vaCdu&@A(T21DX+90bEx_u+h1n~1-87nwPG z3^pqw`R;Uo2`U$L`}{2eVo=GML3ljkagM+>r7%tu^*httjnr^aN3rA7H!V$ViL)`v zX9(|-sAXLOFD!Twn;-+r5Kff6c|X({*uttQFr%y zbBgo3FV&hQS@?NjS|~_WmF}7173q;eFZ6zVHgWhQlSQ9ZGeYDDxm-|qT%950l@Kl2 zadbkF39sV*jQqA{@`~LzS|1-VBpXMa#zPML`;_eGdX8-kIbsE;aTN|mnl@FVr@Pag zOitk%Jq#FKgoyWptcbip;jgRnUKv-tYtj&_+BD|#e5S3`LTRvUFz;hV%kHzFd12g< z&a~an_lrhk(CumWOr+Oy4V<_@4vA=F3LMJme;?WGWF7ZGh}>%R%9K3m zYiN?BLdh-S;if_0kIM4IJ%pQ@DV5HmA{@{0p@##HLX2EIfGN(s+*| zR)Yli_))sufbKK}{{!gw?U-DWU&E`d&GDYu%YUgD2il8nlxE)mEd%y^6OuqZf%4CG z3GSoscmGw9Kn-}6DVB6KLDld3;c9d>6q{3+s=8ZH(L8XyG-MgJSe|a0dg1yF&i6av zinBD+GB08M->zkOcrRk5Msi!B$W)K;}K$HL2!KPx(ClHQD7DS2L;oS zaB3y0o&qVv538U|*0BJ8FpOYvn)IVo@SoQ8sqXQT1p%QICQD5fR`h!_<1Txz~zde_L~ z<3h+vPBSd-4FRv)_(gdA+Mmnd(zVHzfi7eIl3>%Yd$xm>~MGP z4)_uA@ngnDj=uy70XFUdsQ%#cHg-59TO~E(Su@N4wT6E5JE>TjZlCElFub-90AoLZ zv`oAej2qM_x}4;_N;03_Z+OD&ywIIxU=Coi!k`HJTy$?i z%PzP4gFG>A_~Z5X95S*g<`7pvjZxPLH{M>Xu`S!U1L4%St6pQ69aj6SzVt$hkU3cX z@(f*j1x!rQc^_I_G|Yza7sZD3pF>(#@uEuIJfv|k5+KW9w|UujAxdWQO>woRQn&B& zNF@nw^+HbF6Y5uByZZ7&!}i->W~kMgPbsMq?O@<2z{{v!FiO}0Jk(6mIN9wWIRBQ6 zTyumj)R$hMMbbq8-pYLN?XfD&T!h`RPg6a|M9N?3O4H(Id5l>oXrFJATPkpBlBsLD zdgKiybW-^1TVIR@xqP=xNI^&P8y$KiIP|UEHSi>uE0cXovJ6YRvdcE_qj3ORo{%u_ z!ndg;*I}Qn-pBEVPoTlO0J6v@6J5phbAE7nOd!3Wb~ii%rLYht!}h?S?*~kjA@1OP zT*SVcLtNYW3YqHl_Oh@otymUFb9a2;2by;1ySX3DR-F`cZnA(uJA~(+FleeDyrBgB zZ==EN^N>CpnUGpN06P0yyZ^?!+tG{F`q(#3nDJ9}<`54T*$gORasH!(TkSsNRxE^_ z&$jn5e)lj0kW25!2I!V=GV+`en0hmwUi7r{_mJO_Pbn->7TA}hpER%Y~m4$ogL->;)6T9i1$| zcFb4+81_k3*I|dA)>jhpdQB>g4FPC2NpC)G5yhU`V`Hx4^2+S*ClVlyg~p47-5N|M zD0{*Gm<7%~iG{NY?3q{h9xgm2TTuw9)`CaghjdId->+vt)?1LR$Wj8IP;fizey2=v zv-}o;gsSb=A#HU$@HqVd35S9`+G&hD#A&)ErDluKm!NTI$g3$TH|Ax$<;7a?VH~8dt8sJty`k8t-h9(5yu}3C-%J zz+iwS)=Q@;D5@xU1E~XMAYJe&pr|j2?qi|S#y^f-KFwiay(X-lNgjCwyp}o1d(7wp zLuv}(lYODbQCE8$J9N}#DxZ^Usj#kirE)E+&hrESVYH}!nq9iVb-+QLMGCTLAHIN| zm;FAz6Vs6G@c;a?lP>*2yTTL@#(r`StFGCNFc855`#OSZ)r#amQ97G=lh-6B< zu(YRJ)J)}pe9J<}uIPscV780;4)I(tBYNEd=pVTlO(wC|@@p{uS`O*^TR5%Q3(-wf zoOBG|l6m>3LDCJWaT(qQ|0$L)z61z%%w(AaYwvLati@fSU=JWdz6Csh5_V9BT_3G< zh*YMQDO<_8eP;m~><#f_O8t)|I%P}#wS{jG9Z-i$QkF+X&g zMp)%V&`;KEB=U!FokZ;~bl9_>cueI=R0nsHs7^Q>5<@h>;QsAB|KfJ=0GNE{WU93#fj#J~^&sKHo zmaBABgk-0_sKFe}7@lvnRwbQR);{T-l+_c+^(-!pI+WR{6YZi&OAsPI_#5AK#&V$y zGwPdqN79|)l@b&GqXi^urtnw6j!Hz(BR_7}9upjaVr$~&<-rLo*SsZOtmiLyQ5~jJ zLQ1aY&Y#^4;vUdAIlgpx*gKzbH_7S2JM?qASj?}Hnz$L%vL77P1!sVBhGIHaPm%3p z+PWl)o$!Q;eMSr%!Om%|m|MiVH7K0@ki)t~%2=SD6~TrLC#f8iuGpfaN zUo{z7h+o_zwE|SZXmvY27Hr`5r4Ols$ZC4Tw*l<0whb6ZvGSEDnE|=Wr>yCqLfv{b zCe~PmX+>@co4=?~CX#@mDcUq5Ku9aU7N#;P9T^i=l`WrfkNUx;VCo8xa?Wdf0a#j| zF@6gc)ivUDWgzpu53&9i^d2njqBC^hjV zlz7)xtn88XTqMgKt?!P?PbIo?BNme%!gBO7^z-0|CXgHIJ1EaGZwN~gu)6-!adBR$ zOp*lqRDM{hTY@NZVpwZVqU0sr8f8xF3QtTB%L9A=($_mW>XnwVV}VTnXdEE_g|tTY zZCbZd;NVzXFeL=kwIdjS9xS?k9`)4!{SSpy4#K{v7~LJLUTq}j9O04fs0F;%3Z2S@ zd$y1lxJ^=$0zIkN=G108-94%&a7?b24t-0SWB#8D2m*Tk@7Hk?N&d(WM z^4hrLf&{>P zyWTph1xBuC$_}*Y!q#(t^|-DTrH z@&n=TAHUE*$&Pbhl(<2Hy{86e8K5FT@k0?Cm?H%>{H1@{gS!Ob>f;fNz7&*}`WOc$ zYDPmv*62y)1bUok6UPP{fw>(ck6s;tymIwMb%o;7>Tjy7pBey%@U9A3=Yb`R5TO?O zmAx6zaRK+;Lyk4#7Men(V5v!wU)xyg12t#Jt4EfoQ!%?g)t*}8{|EG@p6RLc!FPKFafu^^$-X`yCx8guf&YA_}weWSlS_IQ|CNIC5IqpLt)<8uwoK zTy0Eb*K{1CAFlCv1QTm5t-PVycZ7)&YfCznl6Z7j;Pa*onN*<%d_kwu-gXFwRX9ki zRD8TrP+@0M33{-NxVOh;0Hcet1R4eV)QyTfp{wRh;fKqhJ3U54DvTyeUx6H3%7o2RAY?gR_* zsfu(l+cb-->yC6BUiGiK5>l-LWsN+ktf`Z&r$yGgll3k#jWkeI#G{IhzkdvC-@sn3 zrkybwW&6q$y05l}?N6FJKsJsE9LJI)3xR|Y&}Hh3Rp;qXR{6!3mZbWBmq?DZ7d>~( zyaQ;ouUD~H#g?e^zbFtHVoiJffXtI#ApVOdi+?H)Y(}P+Qo+;{#{lTbtRSe_G53P>on2>_sL1+0m0(zpxoj6WXm;IF@*hg9DS!R<3qK91le zlTlQ}DbZiZ`$ve(5(ijS@*&S#G?@e<8WVavR+MBc#lZ)h)k>7%YQ9mdb6gXBxFJfW zssw;jNXQCQ&f{c`w2sCxQmRM|C%DItJ7mAEAn_#Re=6EJ6o9fbd?58w+Zni;Ujl$< z9d|ytA%P$HKP(C&S}&SBJX1JKF@c0D8(B0aLlsZKEBDbe;&piGrZT3xx(Ik}^tyd` zn0_Aq1C~dVqN!r`);qb)ME;FzTL9W&0q_TPw7XbQ(c~1)kKX>j|GsEFL}E-fNrM%I zWsm)Hw=7&y2ku#z_BR^{zR#W!0{(+U+X%2`mgR~UwJ09J*Tz83Z-@f0lL#M}e&HMh zcjri2_6Oy&hX3~^OfXISKv{*Mq9;jfwniE20mLN1JZD!?I7T z-d;j!|CfNo!;{8C#e7eX$zKu;YnC7u-N5M23SV3GG0k}RGO=j*fGnROFBvNO|Ff-S zR;srQig7Q1I}T1PL-saAfZISTMLtxv5;SG#FW29gIQ`R`cm)+Y;{ig~CTNNV(00Ml zZgSjH^Y0W7^QHU)8K#ruhJ0jiaj@&#?`m&`D<&W`hAkUVvH6^z^!Zv>4>}@YP-$RaA!Lktf!aI|vz`hJ-3B{1 z%0TKYg#oe)Sw%y6R@vWbp+|MvSqbJ9YqI(K6!1an`! z17m0*qb_-63n-C(kYTDeSALM6?t$A$LhSW20ymHkYbP%pume!L5CB_=gmOQZ*ZgQL zi9k`FjGe|YfSz1nGfTo@((t(CI<(z}he(H-|F>{YiOXk_wryKLqLJ-qbq52+CP;J$ zWYD2i3H#gQQMQYuMl??M5r5l0QbFb9$u4POeAm$B6(cXO3CRcW@y%c8XK z<$~mWhE-1%9|z;1WYQ9+wLh^v+J@pL2WFlAwW<;+8hAwNy^P?HY3oUY0P-S0lMSR! z)*|M0rCZN~P5tV5L-fJOJ5b>;3mNTOfj4u(8Q|Kci9RbWP(qs7<@)0|l|BGKFs;zt zJpkVWRsf!=Gcls@EaHO{FN{t6^UgMpq$0ys;?BL!@r0w4qyR6106mo+hx-bH*hs?W z0?AV#mnc}G3&SqFAWVK9=Ba%hbnR~h#aV_V&d)>b+8k~@azkOYp06p?x{Vm>ve4J5 zp*4EzD45+g2S7v*WDd@Ac;jAJLP)kf$~Z02H4iBKUXY!az!d8M`dvjabu;`TJ4iVv z%E{MNW|qa=BqK8_EvH^4O&{Aw;Gm1s#R{@O+1^3haiZsJ#wK9G(t=C77C;V&JScA2SR>8j zlKLa&7@Ld^eWaa>{Y|p5gREQ+`>i?1 z6Qpfs11abwEiJ{L$xHLJAlP4ozcGFg*j^pVm{>nD!9m4F#hL^R*5`M3w-yHzpZDdP zYh4zMyx?b_`gRS*Nf-{}FSuVka<9=O?nKK5<*YtHyv4p@leV5FER#GlGHuk-lL36q zzNGeUsL4$t1=iGt)SIc$Rd=<8a0)*MqikHlUS_L#Q` zR44ibMC@~!+`yhd)_f;*b!-Y*O$3pR5e`T@*UntwN{?6rw}jeyumb*=%XCX+Z^O`t zQ^=d3vunH;f0};;qmga%@)C6C_QTw0e%$%rBR!F!Q7c6$mHEl0t?i?iq7f;Ihq zpAk!$Hl%fB1)GH7rByN`OL@lkD?>c34zq~83hm^hxFXGr%CkQ;1yV;7bqO8>dEL7R zG^QOIRmz=iIW)zK7)+u#JD9icpW?9Q*w7dY^m5pR4suIq@^*~85-HnC>8r+<`cqgc z(uCh3;5p}wxz!jx6uA*$BzpWt8*sdTtGOPX@UMRG<}xZ$@`D~urbgxmA$|N7O4I`T z7{+8=N0apzSdMdO_wWg&yO=CueTz>hX|7L@%ssIGni3l>Q4O=iu+7>-qv<%|F-ikV zhtMJ+Rwv)_Gw7JfpIj!kr2eu^WTK3~#D#bU}7YJ#E|d`d2Ah-m%WaEm#c6cPXZ+WI zxh!!&{_=(i%gQfWw!L{UZ3#>(~+K0qGR6+a&GJOP}6WY>w1%4kI#>~I|M|6Gbd;hW)@7+q=Il6$$@Sv zHwvJpQ0>IC0zIgJQ-~0vqUXi!fm%>*I1$)xgX>Tsb;6~J68l>qHf{Utni}iKI!Y8X z_^`w|gV6>1&fXEJYQMF%a>4snsNb5GK2MAN8;5^X`r&Y&<}k#bv{F@>uB4Lt{!EG3 z;{8WT@ulWS zeo{N|+|qErU;WVdMH&8H7hju#zl8Fo0$yJq=W6v(JcZg*;IExyPH&!5j=e8PadPg( z@@NpPBuh$eZFTrLx7dFe&V)^+1uDH*rLAsw9PwX?I# zEJ^qaBVuIP5iZ0(qXT>m_Pged34g2UBRsUden?@uh+#V(9uCTbH1uNVbwYN#zL;*@ z7!r-C%5*B;;lW}4nnxJdxf9Jd76nOdg_)(^E_yw)B|eZN>fxW#tuzZg}j9h$UF0v|vE#NH^GHK1_BE1I;IAULr% z>|Z(v`TTr)9zu8J=<0d}+_!H-5}pvWIoyVl0uhzGlg?++Y-r}XH(Bk>(H5@M`5M?IJDWZ6@0 z^QFdo=W6;|Sac{;d9Wg^jE8uG;i^zZ))uk3iIcrZbj&7AjC3j*TV9+brG&`~@j0Ll z@VCO#n-~;#AMF-k11q8ew69hb_uPC*_YW~m#}9f~tCE~Whfb(t(2idribXuCJcGhN>+3K+ei*MCe4g#hYOxD5kU&`uJa#_Poi-MR=j9+w3f@d z7?Qe33;#Rku;+%#Jnwr%x)Fglwuc>tn=*O{7VJ16OAl?#jGkQ#K6=0tc&}(=ixBOp zRvmQWYPNaeb1LeY^(*0jWnSVoe=`FBJz@YA`CkM~?QbSWsbJ2XhBQyS+dbypX0PTz z%#c`-;;pn3*h&pxBgaKOM-S}WOe)v#S(O6vA*fxmI>7lNZJ5d&RZh^iK(DNe=snCD z5t`dBa8z4l$PGyc(#LS`J|4nLT|MCM-LhDxdKee9<}n&iupn6T-9VJi&{{DNc*&jZ9IWOudo27^go)?uT_) zj&Nz+q(Ms0s%fk!-`WY*FpD%&tbMAoEW62qgtjWI0AFMLeNF`@d@lAiu%HYu-5r1VKOrfr&Yq4~W-9b+8_GG$J7VQGeJRN& zAR4t0CChRY5L(q2FS3Q+lX)pZkUWjlFb8d&xHelajjV^G$dwi&AZQcBxZQI9*i(9h z;OOdtUQl$LTPhtB7fWG`uB!Yo^xo4+-Ef2r(Vp@H%}hgZfg<+wOaPIv!`!UE>k{4} zOFP?Hh>o70igphNawi#{tJSks6i7JSfslFhR!S48VMQWjHdgB~(Cz^L|Hd2V7@#a9 zNrVitZ3$?T*qcB;{KpOC4G8m>c_tFegC>zmp^v-xRe7>V>_*ogzHuy%Ac<5mX;X!h zv=u&pSJS8<6>ric0_CRuLAc^7I28DkbJY<31QL4z17qWu{0m49hWlbsFBtvbkl2er zO%@XBd@g5ifvh5`iB+@1uVr0L-r4Z{~%cex(7ytwdLI253U5Yl6i@~z&%0(Q- zBNSjoMnOY5*N5YVp^KBO3>V70Iw5YfkmSJED6WgCf~}kl)f4g z<}D%ho|=e|j78r523wU(PDqVv#mUa_f7VV>E-%V~R_s;HAt`lEmzRbvowrFY2zIYhm?Mr{ zimK(kWAH6fUZ=v>=QHG7DcOe4^no4!l71rWPmWPHIs ze3?F_ow}gw2m3ge5pZPUt8Av@$ode)R6+m*jP#S(AJWutr^(HJyNZ`iTVq53w60Au zQ^i6xBx6ASF7N{S!{t5;ux@UtqW~|%EtY^Mu>jLV6lvg_>xe&9t<)Dtgy{jc=_vpu z$E^wgfVc}jAlPlf;>!95v=@^wYB0$u5981K0B?>nU^6LOBOuqP=CK_LYGP@f&!7gQ z2ZHtGU;<#O%S_v2IUUx;A#j-Vk-QP4{3o!uYi$Df>VL2FJQF}m(;-QO9V|5a#qim% zv~W+%N;TX~D=`20fdh$4@W}&Bs!q1?O*Y@fs`AgbGTQt$U&6^i?Ls;k6i7O!{C{*# zW(UXio0so_&E|m_zpflOG!jkgT>jTnu-F#Dh8TSb57e!l7SGSP*0zf$W3PTcQ zsg~xxRJfXsxEr~T25pkq7nYQfYW7w~@N4wLrnCa;>=Ff?bX)efNw%#ih|l;65}1t6 z8t^r8G5+n>1F4l9{)iN@#?VfNjI6|bAoXXI!i6e*HYzL0GAiqXvEs%g%4aS6PFBKK zaEd4P@ZkbdBN?r;YjlRd5N`uTWTe0qs4z!sWv1UhCmfp@EG;s{=QrGMHPtOR_J1rjEF9S|9au~3Gmix+BA@rljK^H?}7G@v7VJi`A zCnY{V%g5N2%I)0>8#^ z896rf1|@1=MGAZKg4PC75#{Hmm>{HbW2>=xEfm*_@%9Ye02*rp;Wzr$s%O){Vv~~J@0l@VA($K zAT)9v3@0+Xm!34_kWhW&pj?9yfrphfyxgJ6K{wfb?Em5Hy`!26qHke|0TOEH(jg$R zpdihF(nAvsiX8X!s7NnLQ@S9%id1>$>aTw5y??&7 zbXiF5otZl`_nbX@?}MF)a8Y&vOSA&Fxo^FpysswS!d~#b*FPFzvAR54S7bRz-XBp$ zcQ~T8K@}C_DI!uDT&d@9b5 zt-7-Sg8!8P*mi??*~ChmVK7nsdhN_eID(ovS9s&&tJGAP`u(L4;95(O&iPFi*h zsDReq%a$qTn0&B7d+X#-zNqa5TSthP!3aj1+YNF!-rpwYl$@cMP{g!2;7lNO152}PUk$*ZlQ!HGZUL93Bf*;5fxGB&=^F3PWToDOQZ3}R z;~>I1d`nw@7X~vbH~J`w%LRqO+8ajSKJOJvp7>LlYBpbCPxsq4k?E6zsl4 z-}#mzJ_}FTvup?Ony&bLXKEFHERfDQYLJ!qOv5=`0~tpU`OxjCN!V@eNy(RX^WmA= zgEdc8AxfWcI#F$_*Td$!Dp>;Mj_#@6Gx6|K*nD_pB5ON)AXm;wIS)|*Kxinfy36Sf zU-#sHYEX?i_=>xv(Uyqh{YSd#@ZyHlnM7R;E z7EeCVho~=Hun@90v3W11u@MVy>qU}^;>fED?^#2*I%~EYjYM796e1c2hD^|hO#=rz zan6~P8vxQcia)QNXlnJ_IYAX)_j@s?B@zy?mD%Em4Nzwvpur}EdPcyDFKdigB*!oi zK-f4k>Ee=vpW7hlJY?n8t1_j#gqQ zFrxDyDBzCB3D^oa-ui{2(`o&!lNoT;OTeJx0*3U zXI{6Zt~;@_@K;%=e2Ku&5GHE0UHv(31HFhu61RGNqvEnGc|G)2`_L?{xsR8;p%G8H zSaV24fYn~(sG>t_(ubExZ5upjoSjIl?@eLr?+gdYmv*nZO55xK1xZ@%AJf5ttX_W7 z7LEX0Ud#c*aSi^Ah-~`JV9Xxd*lS|StT8kP2Nvq9SQW=`$A$wriZ`#a*^!!@r6K%p z!aeKJ6C2S}%7oKH6>s@?j2RE0kuydH&Dv-~Z3GKTA2&~Gv8T7 zOq4Ku5B8Duoy0Gc^3zTu^%&KO$QzLPuHInHS|VwzN#!Ut(bk}D)?xRnQj6mlNbk$) z)SqexYIu!n-+_|p>FVdEK^55;9u<9pw7vUg4Q8UJwK1H7>Po$W=auUiDwP=cgmq{O zlM~+f7zM`>U0M{1z!9C*JyAt@*qWG+p#(h7-%+QSNekKPuTZ$jR1tj7yUZa`a_C)A z)ZcM;I3mV?z8xqUsq8yMVp_Z`hWcc=uj)e|n;?m5NMEmX;~^**_@J6!DpL}CqVd%GYXP{wd7gxdy5@nw4`8YScF<6&)hey z4^z~^hER=I>vji?5!0zt&PEX!^}9pmF*@_MhPoGrCy4zpGGwk8CMsC-lqH-Kz)KS! z5lb}l9Ew)LP8QpzkT`9!?rWX2)tA1^133VL_ku==eU7R#RQ)*;8(5t4Y& zh%NQv_=ywztA^LAa}){N#Fu~Uj7bg@V$55uz);1Q)u5XZXF;bTe=w#F`-?4;DoU6A z`*AQ$KO3Q|MC5MvWc9P_?`fU|Q}*e-t`XP-@L(Dzw4zfKWreQw&1T!Acx#Md>@fn= zCEe-Z<=L~dx_Njm+#k;55C!}Rig3p#%AfpT%z5GFg>I0nne3|GqtK%lUS#l^pG~7K zlENs-*6efd3_q!eo$Z9?SZqd6WO`&J&Z}qY>~EZVSNB@Nnb$m_Oolhd>aiL3z28sx z>2go&&xY0M-8r6M9)vMN=<+A8%J)j5taU{$II+*YW83;IG4x=41^0~7`FT>AP)B0h zXZRZywmju}v`r9kc#y&?V4lHK85J;Mxin`@%ptZCPiJNahQ<9A9EstvnZYp|b;=YG z@){d~DicQz;F#yXdGlCz`&NJA+A)pYqP{mHg}GezITpJ_@t>q%m0IUd9PH86NBKyVt&esJ3wyFA=8FNmJIWJ%(Ql1DxKss0+a zQ+Wovu27M0waxVTb$8kBHPvKs<0aT$rk3S>ypghe<)}?;xrpYa*9u!lix2>3+8~#u z-2YH4(T_RJ@A~3#1jAI%6l3T)S}FRKZA$J^;Q)q(Z& z3EvMB(G~ufMZl%*RY;HXbqKRH)o2FqFvZsi*DVR07+IJyf@MMjp*s7nGseILc(k{F ze|3JnXx1k_OHRV0aa%@5n@~>2*e0ycYbWKAU=2@!tV?|Up-G%T#jS1&eqZeTE z4s)y>(Xoef?o9)6opkzVGKGnDl=w5%O5#(J%9DR_j+@P z72`?GL7Yxq=sq47QV-q7Mu+dXL=0h--E=SYdYJj2bz_c*i1n2ZR-1vkJj?2o{5&fY zUG|uRO~13&D1x>>Y~O0kTB;Ug?fICMz!k4qj$kUBdB$(yRQwN(Mc99G0Yu!96Y|22 zhmhB;p{g{$pLsa3ztUCy&eL=(<gI%~PiRY0FMJT|1Bdr^$eyg0pw zTm6dO!HiT4&4}$)c6lBWwq530cuclW=mJ~f(7Q`_ivlV~dkJ^RIHPikXo68enr2y@ zp$iaw;-FPVlgMJ+ztCUA=+ot^V|PAtCjk8AB*RX`wxe1nfofV|AM z-)4th#F!~x^SR-3gXNpX9Iz3HWTvcVV}E}rG|(Nyv)707Ip^#J+(=(i{vw+$7p%E+ zG2Nwza&U~(tk>U02_NkTn91jAt+)m}4y*!5R}gZJDEwF){X_{kJNwW8$}rFvlmMd! zIoXK|isr>fPa@ynVEF9A_USI*Lgs(?Z0G<{$pDfUJ{kv@(0^Dba{2?}66Bk8MN8y= zm>D1}9vN?1^D2y)3S3?09<%)J|5Bo0fTo14$tFYQ$e<30l16oBEndfQv+jCNHq*a=`PnF2KoX zqGb@e&B<%(y;7(|1GswAHO*!l5$)XQMB{u>41(#qKx(D%B4*h%+Fbi1ci_Bo!ei z92)l%&GAbhOeOX&PR4NLf|+?g+F0=s4v*hoDHg7I#)I9VBTCJ+vM2zg~fIk5|P55GRFqkp5qNkPd0OJLpv$cob{)|1cT{ zf&pwT0KVOdz-nF~QuzMnqkv09#Y%_1JEYoupuV>Tn7|zs zqVh9t&`EhU&9N3s_Sgbk>FZN{9!nm!c#tLZozW5evGKTO5)eqHV+K8kE*I1RBql(J z->TYNnY{Vh;^c>ou{eX`>3M+7ZYx9z9{b9oymI;&peHks0XoHSWJquQ$;nD}9toGX z9;~bYQ0Ve>WphUM-H-Sy0-J*FJvCbHIK7B7 zz^E#*H+iYuCyTK3M=5G%O@NFJAd;+uXP-rkUWCjuL4n6w1tFmndV?qjA$7eiW1D?D z8g*<2rm}OpRE7n@fC?X5bXg@dtqJCWfk`m}#L?`zHB$&$vUkdj^Xp826MWDrzkbB^ z=joeU0L{Hqx!r5dv%Fje6aUHFdjYQ|Qi{D~YkGFdPZVrC58yOKZ@&-oVN|cC2zS=5 z2^Yl1VJt>Rn@S|}5r&8_h%Ok`x*8yS*#Q&fbqZ+wSHXwlE$0)z-SkQ{KuM7ze#}4r z;(~j#e0Ql8qZ=1-gE&Nd1)$n%FlPDs>fM~uZ9x`kQ#wML50MtkuS>&8r<0ukV7UCw zeU`c$RTxHGvNmpvKjea!StaO@oc!IB{1AXzlmTA)D%P!(IPwQtn5;=f8sU=f$ZwPE9N5+ zvrF}wVKXkBB9oy2{JG~>J_7dr1yGEgvUSmYh|vaN51`t{1Ye5y8@v9m*fT`OZ6G4W z&(jCV?IRYb9xh+duq5DV{OVMOd`PyL>(WO9v8V^~q?H>|o46yPm;l+$UF~(P!@d-4 zBK|s?Bt`VAd=F0e+sCVA;0N)HTwePW=u_m!(K4SZvsBR!(pbxv`a^n`S9HIfG(`+T z<$l50LNpV+9U2}#SJ4uo$JUK-k~o0#1E|zRlU!_zdztf(8Q6F~aS(v|?@SoR1c^lv zYZapgeKvDl=0bhIbH6slT^Lk1hij{`MDml{b2<@whN-6wErE1T3u^tg-#s^3H9e&1 z$j8{|`Pvx-ETqVeb*gJayKgTUq|mjfoZ%2?*y?!e*pbheY(0Pl62T|50)jbeajR^v zoV{&FbkZP+JPJNvV*ma1a$Vvo+djN*5d+r&qPqG@pwXUJiU10iOtZ#Jh_QXsbPbHZ z)~sh`#Pb{HCc5F&Z&y+#PeRM6upJb*!`F@u#A=O`R1h``k58zluOj`)8VP&3Q#?s*c!cQte_ z-_nXg3Or;R60I50){}-&-&kHAo>-rKFK`#ujPw6v&BUJdb*Z0}j?~?);ioNH#>13~ zit>P@YZkDQ+n}u|1NWfAvFd?=aZUnmY2yJNqQ*G@gW5AmNf2v34wK9r8OTKvJwIJ$ z``(bLI48BSHb)BNo_I%cXm*=!)sjEORbWbo@_MM8<9zj`L{h+e-cwYOc*?gnaJ!{2 zu6F3sg749z&72ieZa^)!ROZZ8GLWXG8nH^Q9*&IxRo zRYqp7&>@in0C6nhw9}c-1ZG z(fzzF)+*K)h%ItdS8{@$Fupr&^KKxD@XB?vJ=fhiYLqK`n7bav#Js{No~C*Q86xX!^--OB1pqCa74Df0NG`t7c*Y5^yW2h}11$_eS-0ej8P+ zz5kOp#{?rkrEk@hc``|^LCb6FE>1IM^~RLy;-s}7WVo-wU$Xo8!*K4$95Ox#CyX&_ zl4NxsYhwxvUV@~$5A&PZQsCo*47WNgiyX@_d30vaP`PtK)NU2;4OapwTXd)^+4K(e zx?w0dGjLGNIUD|fy&Y@S6bSk({RKFDM6Z~YVA2?tWHfMD|)~1VZyze27 zX6D=w)9NEu0W6@Z6X?Xe%F&+wD!9ZyW{MPvstt3>&Qd$)DERHsL9IW-<9fNfbM!!p z>LyR1aFDmw72DZ2_HHxO&|Mr9j$yVr|1;TZNSzZkktBD}hlNqV3x5(Zn30#LS+5Pj zpnnNSuhfrHIMCc0XwilwwnWjl{(FARnpYN6JU0^Xq`tfPl(?<&fY+MRr-d~h(DM&9 z7>S(5rS<5Y;I_Vpkk_v8PhBm?na=iHBdPi~-`6+hG>yaH*bc9X9q;)_GB3UF_IU1N zYe)r;weZv#&QvwYR?&P8I4O33ZFRCjrGSi<35Kzy84-V^%|#V(6Cs@^#;_=mi*snI z_lx&Afx2;z=`Nk<496+%yZN#i0q6EoD}bZgr|6Esl3(e?qz|v)CHEDx0yjd37)|^_ zd?07&fIC1;_!dw&+nKVaoIyLmZOOH$NupS>xYV>~UKgV0KN1wz;#LHt2%MsN({)vs zCy#}W`1TA}krW>BoDrg4psvCjqU;Xcz#r1AMyaQ1HylFE1!Ck6y`*X1HSkI+OfZH$ z!=SuavF92T%1L>D@I5Rsl>no>~}l1(|;E)98<69gnM`Rrlbds5X3Qcyd^hwJFb3Ryc8HomAMEQfN=BgnAJP{5#n zlqmMbPBSSCjAytR-RsWrK=g4Rt&3+BHdS?Woeh&i7z4j&igZh@H9?u#ZQtUSiA&S ze%*-^-N1tw(^n(=DNrE!e2x$;{yGGm_JI2aS!N4Ca@%y0>{smf|NQZPq`(~rr;1O~ z|0HC<@Z9%{t9?I=!Y@cc&JAXE|9@R1v=80+RN3=dg9z;)iI}5VzEA9XHy**Vpa0r! zAe?Zf!cB*HpVI@U20DepJ(9JI+WhKT=@sN)cM;|t^4~NDfKGTJt`p#Qpo?7)6urB9 z3(_?QV7(sd$d`YF7-mq^?g09$H~%Zf4H}D&K99b$^797BG78{CxXj_={MX#%MOpcO z7i5FJhOWYpAm58&+XuuwyYFx19m)Y5?u99R*=*ES# zb>o|0^niHA>>)ZnhJ@?zR0KWKp6k%Z$U7;B{sDe{GJ!VqzAH{jL^_yB>xU2$z`ao5 zOFU>WE`u@E6O$}Pq_t}3f&_Q~tFLk`{ksN^w&Z*g;TDGa8Tv`eL@=!lCY8(NM>;4A@2R5KV7+9H46U{8AcuL=7Q;urSc#K05O zPF#E~UHOT*&G+9Ot-(J<{7{0HHmpgCnmckBG571s%-?B&Jd(xOPVax8bb_CS@~7AX zWd!H<5}9FN8ecU0{n{NQ32b$XK{Pt}XdC$2Kl2s%r=l)gpPt9~EI0DC{~_aoaTAY+ z=VQ3Rg4iUmdW9jy&{VQjnl;4UxpFc1WkS>6n`R(d)(1s!RE0bsU$E?O zxbW81dK26SUP6?N|MzS}P@V`v#`|xsggvMR&dM!jKDpN*UZ_Dr%FLU#(*K6zf3K5i zjGSFxm)0ptYJ`xo()4&17U`TXrVf^}xR3kWvwsKvu`gOXV$Tg$m9@u6viO(#1Ii|Q zz!{{jOTE4N@9A7YPA#WD+nD?v5p8s_{dXH=T_b!Km`MnfwF-E`CGcEmuZ+;}ZVs(iaN)RRy>k!c7 zT!PnoF6g}Wmz<*Gr0!zFP3hr85#Njyrw_b{B;`sF0~g-C*yhUt@LWkkG(a%hEhyYm zhKNdOHo#v>y><@Sm=JxIJ78h&tS*A);3X(pu+!HB|IU^FoQ{4E*oKOz=MaL(cK~ZF zEmPd$5pU08NGG2*Nibg}q7Yx3SvbJM>b2v!PUe^dUAa&CRskG-gi(VCeh^!cO+Xus z`GjJm|8A4Ia`0kAT$EtkkFfphEzDA0ZjxPv4Wjg{8+^jANklF11=O7TTIzIj8bA+= zJ%%lC z#&lk5%w^joxY2)0LUDe_QHXhlCW~U7Mbo1ju-Ux*giPmsekB8U8WF`5Vkdz(KkzCs zCLV&-Vvp=3c3NilF=mQ?H$9MtTs@HhhDvG{93pSRjq(u=+bYSd0{*8D!L*Z~MvLO3 zmC+hVZ2)oWwU^mAY1z{E+36P;^~z^JuiouL2Kr1KLQHLTJJN#uRR>Zlk9O|ufYt*-8KG|eae zr?NBvjZZ0xE!GOtm9~N9!$mOTly9hs%-CRn^R}vR!o#G6Hqs>KV@@40gR8n1$$0|#LGWgA|Y7EjuUUc0t~!t zV*rV9xvSPGXO94@DPA45w8@AoGQ5%~`0I%HX#7dKa5W1yjIuRz7{`yNmS2rN8^OC- zrL&J7=E8qd?Mc-d21XUcGhL))0P^GD1Bqq!1x9F@Aed^;UOs^8GE5SitS-vex+d8>tp;sRb7*8%IBVkLdrbxHSL_>PENoOPs7UDgp z$Lr_yY~VUG**2kIQT{bXJhqzAf%P>QkVzCY7R~9W4HlUQ@g|6)%2pK(*tOk~Es%9} zOWlk-h51ybC+HPP^+1)47yVpLUQ^B#>P%C6<=xVGPs{8E9!_J#{WCT`F`%673ru)$qq^VvVWY4v=0 z)z#Pnd4JCQiu;8X7nzXolm2r4Gj=SgGo#hva|xA6m&WhvCN*nUe(*^e7!Xdi?&wK9 zwU$#z1|D`_?Umx8jjazQ2Cf$^IFJikvp}CmsdU{GzC% zu1xPvSf=dk^c`c}cQ7m~oF|N7pDrCXHXv3V{Xj%dd_r7NAO7j0El$|0-y-5~M#4Ra z5ScWHW%9@G3kzWt`drSt zyjQQ_Xr@&|M(dV9M`iNyDScj9vxt*P$Ib|SyC;%6T@quW>81DU&X?MvQz5^256cpo zADAa6=xvmAD7)ol8Cxh-y8lSJ8=d>wMCtmudpVb$WF+=$luO_z{=imiSY^1ox6tK& zm)(VnD%1Un-`q*#k@m+XVj7ZeZr+E&wL6zw#FIH3{atp~(f<~DP+!P%aCZK}tA|w% zJBR8UITJtzeIjYaI&*;V>sCJdEMG~&uNPlm9I|^?{P3Awa$x`A_j>E4!ylJ^wrlp5 z*j5Sb{3t7u`A1r zIy+EsqI)@P^;g*CSXl$|_OXmRWu4^*gTlBTU<(PymtAzh_Ili|zgl3aLSwx5#z>%lR&ka= zY7A1}**}p{+%s&fBm2B&!?uZig~P8U`GUJO&Pe6*hHGoy*5BUdhY4nCMexd8_$ql# zti<8_*Xei6GtbE?f$XW}@y|~iy5IV~ALSejiu zuhjDW*zKoIza(NOwr~H)OD-8(s&C|LSIactqF|ku?3KP{un>RnDKat7a6nP`+vf9*<4}+01FPJ=k&Skvp@6*{F zjTe5Z7y0ZCcMJ8n*c-kt9WH?u7*7&um#Dn$IeXiDGx97e+9DkFnA3u}$qtGV}i)UE-rkkSh zaAh+i4a$qNvKHI*y_%>{58vQhTqt^BTvBYJdM(0u(`nMk`FgA09f|F&#fdVV zxBXeCls=_zhjc}4q;)YCTFKsK@V@u3zol$G)ERvY2PNg-A84byg2T8@e25_SR@qi{ z9k_BOg^xI@YPYGElI~Fa>cPgtol{O~vI62HA@1C3`zBhbarc+vyRyQqWiw;;x+I+N z!_&%{o1Ok$BV##Lm9uY=^t|FuY~dsahFG}Upld--ik_8Frvlqq1ziq3BTs$r@NztH18Z2BQL!3`%c&J&D!Nl zP2&Ukvy|nzBc&ls)3PV+v_D^Ij5>r`#Xg_dc=z{l^cQ*?Tvl>OJb8KTq1>ylq}HEN zPgHG~&*d3-O(wqQebHYOmHKf=_2xaV)IuJUyGQorP7i$FJV1_E-d%Ga(_G!Ydn|xe zHTw79XJO8IkFB5iW_n5IpF|NigNZd26zb5+$GN@vM$aNrmL3R}E&o{ESoeR+a~!=7 z&M-XlPOCIrkxskYzQd?oMf~{Z`0umacb~R*2rFzqP9rd1QHW;vaEad{GgBl@!VA)bZ}7@Tqk^B_B%p=AU;R)13nwsy}-VR>j8zXy*{s zGij<`oGEywa`~I{^}<_&1;%enX4A=9vl4D&b1BE1cLsLvzwkV}JZ@v~)8|)ls`@K? z+Mk-I#4l|J=N`Cfa(7tEElK1;a1=kj*Tu5iHu>wV)jW~j@{_CEi$g4pU``T4MB34CkpQlckNdSNPN22w8t__f40JMrpGO z33A?8toa)D;PChHBJ|Qa7`3I2ok%>W8|`Ypn0U8j=@{36b%t>-d!qz4H>^*Ke>Zi{ zNG7S5eX0JU*jUf-O~V{3^P1L*;9qik$xLiFGR=o4md-aSKeXrWn_v`@IVq@midL^x zRC({k0jl9}K-h-#^|=z4EuXox*$*idQy$3L)IWy3`-1dM;Z*`uXWx;} z2=c+>KD9(1C8b&$Gry4QunM*0S8;mOlvbuz-;gTnC2mx7aYb)B_zS=J_M96#l;++R zR_r*{ws4NeG{Ymul<$oCot~FJWPjZb-lyi{6?Sg_IjWq^7hRRvc9VPew)@=nKYAMd zN6>U^>FL3OpVL~gKO;#6PV;nS%dLWL9ub6+(pUEq{>a?>fJ!X6^UAC_@qDL7v=F<| z7~0>mrYT6qw*Dlq>RtbC&$OLcRDEb#%Oj}}?AY=Hy=7N?wtPaK z-4J~<x7o7~1omMm{v7J3Tg5Do8}E6r@rv`> zP@duJdmF>IhSrm^<-gC4h!?f;lt0e6TC>FW>!nfFb8#QBmuZhDOT}K2|o<%523q`4@^m-{ARvA0$k^167lsn|F28a5 zonw3_x5QIw;|?@&ca*b8+91oRV>X#}g~M039X(!rJsorb=Rk=YnvfKCGNtlR@7MN* zTiHmo?BCE%n(&_=w~f5mPOVR6Bb2hlaCh0Kn>@Lh7qu~-H{>3)y1umeW$;%aJ~R03 zBOAA!*~V)Vf1t7#CxHdr_5EeBr>R62W5pMaIE>lEhnGb4zzl<|N~0o^S2l`?$88G6 z{T%vA^pxc1$IerS_S?_1R6I3o4l~TSicMTtCq><_ZkQTfB(g*_+he< z-}lz@&U93TcUc=A%eG)V^DvjcWXjUtvKc#NY}1t|sy{B--fjOxXs=semqD-7_`R%( z)Iz8T+!)wOu;P59@oPzGkVD5HyG{L2Bh|rT^Vhd)YqQ7eLz`Qqn?-2?cGY8lqp({# zBBxW5+qm0ga~bPjd5m(Iq}kq;!i1TgVG6tb=RL2I#~)wrVLoeP{?(TWJx);P<$|$X zGb3}eqE`adWHr8#1;_ zlFBY07-ggFzfLTxY$!|C%qA6EJ@0Pi;jA4k#IGgqW@G!Du&C@E{o3QbX^!E`u)&t? zfwWaaDskwW3qy;zn$t*=+mDyE>c~hso;;uO%f+&_fxls(%^_N$e}S+4AWxoEVBPXN z%~x079zZhp@ukHLbdqv+vx`Z5+RB9{v+V|e2%ZSs#f55!d27YE%Qu-Gi=bhAJB z!0r8d{j~ey^J}BwTFUnn3U}rPtMc+j$ zuP*si_e*Xq9&>3`8L$<%5^4`Zgfmiyeep^0l14RiOXS|i_;e1=_JpR3C(>Ty-}+~> z;^p>rc*}lCj=3JQ9pakYS-Whzqs)7O(Q_`%pvi7xYn|i~#DRhJcY{ljULVs|mXChY zs-%A6(<{Mu=%AN+v4n>j^^BkS^#&JuJ^TjpA*DziR2$3}?QeAv-gx_VmU_%3!R10K zdLOxzpgW|Ba#@ z_}3T1zb5a6*HAoXp389>rK3^TOfQ~l`+L_H#YYL7!Zal*E{DsHLDx!erYP!jz_lLv z*6E{yegAgBw>L?yO9KVN$enTXqJjimYP3Gx`p;fzoA;K^({pKm1Rp$!@x70!i{QS} z$HhPOWMH=YP;1`y*Ff&|RHRG)pCGR~t~=FSqQ91#PaVa!cc^%Krg&Z3_1=PyC2pwl z_6H?rZUc_T#Nj9rza7kiX6lt1J@mKYwC@Vn2(#x~_HLm|=ij$R-R?_(w8*_k^R5 zD9m)8CIN9AG)x6|AygshKiZ%4oxZ%K7n+t*QE8tdSRZ=jV(PL_Ftn^}Kk|$>d+nCp zI!u0Ctu^r7RN0#G$nu6~+;iK`DlrXPI;;5y#Eb0mHs@5&vUn1EAn}y?DNK+%*7Wi3 zH-D`E`c4QEY{~2GEwdX;{!u%PO5RnkcRf!ywF<{a4NNBk&T@&mmxIxVtI_ZUX9)`%stc9v*_(YpIluhh|tnv9g6JrK$$$g-a zT~TW4Ewuar+{}1DgthA)NPmBK<2_x@`4{mZCEa`ScVeQ!BPFJf^%rI)9GUp7OD-L` zmO2=naDU=xNAWKg`eM^T)!R4|HIxrblo87zftE`2%eCm%_(RDnKil8uv8+sXX3@52 zg3ZgR_rFSrWlp2Pb>Tb)hp(23=*5admF9KuJ;^kF=kB~Trq{U~@iPcPbp?*OpIjPm z=708TAGtQLF*e-O=TOZaQ^N78%A5S{skz`<`o`~fobMM{T`wg*|3St3CbF)B!6zf0 z3ub|)>I0hHbh=<3Z3YhCGEfmJ@wjKf*ryoHJUAA610Ks4qGcR`Y1dnIuL?e) z353u6P)aEVa_a^-sd@^CAQYkD_G{x+a9;PVml44ToqYX z?HL8&*EI<)Xp%-nx?r6lH3@RylW8i(S~?5CK3S)dC6IlyucN~2wjisE#yM;|CBfv5 zo+n~owgl?225`T43g&NoiZy!xQAhy%EnB*pcs+{J0NTY&@Xl5W0b6ff;CDXOeE$a; zcmmE6%%Q@_($tH*&rrqa2b@a8gFI6YXqf)5mZ_FjId<=E&wq8h3KlZ2_SSLLiCh}{ z{7@f^f%aqB#iLbswmPTp;$|noFTLd{zzCn$-*lVA0T6ufV$(bTqlURg@+p>JBrmC` zX|YrnR4<^Jk=vC~UUsE|Zjp1<1O?%j}cJe;2>@%e{$GKtY`}{}ADHg;9z>3>dVTQ+m zFk2svC4TjZY6fFc2Sx|W-vDa=;(dh5@PjOaG?V)F+8{>S+8^L8n-u>*#{7ifZnC_< z2%KmWFNEN3Y2}zc|Vw`$ANgCGbF=Fjjhu8)tTNe>+^$ERa<~icNAHlRp(=( zx;#Nd#fAV*U(PoTdX_ZIM`0t*U#B5d0cq$!HBbsO@t6!+cYpdi?0Bm54Gp)Skxp>r z_QLuGls7W;I&ydh)ZNJ}%xCCygpUIs zUWmAvpwiToFp)5E^fYWCF6I`4$+O|zF}}vI<@cEHwbIGtkX{xQPI3tVe$wQkhW`aRw?!4w0pWIzF}yMW~+!cashO z#dh4Vnjmc*HiMZ$N9#aY8p5jbjkD-%0s8*+L0((m6P^2+z>rCr3sk`c)okM!wVSb| zxZ%g|b96Mv=-9XlN6^2Ap6;fUi^!lV9k*(5qrCVT*bn$ZeTkZ+gGHXfa#F3Iup~Oe z-Zqh;YPM0&WL7kN;A2bVP*mJyys5x@)4d^0wMu6Bn?9Q|?r+$!bMNEA)}KavbZcJk z@_ez2eLQ0x<1JXUeic4N0~)9Sm9aBkepYlwizD^&1-sg5ETL-lk){PnO33KtbPd{5 z^l-%VP@6-9yg>TysJ(ct%e;DBXNn_(#_@0)4yzl3uRoh}Hudc4$!QH^Z8$Vqk;GmG#o3!bpp4MrVqfq4CBs!=eRM&mLb40}0&`3v}kk^7|Sn2&)|^b$#fBY$XGv|U6>6ox6_bL>0PI1cpq>-U>Z&Yv)}OU zGo@GM?BU0=;L#&!YvV_ZftXuCDW&wLANr4r;mN0UM(tFxPx%i8&f>%`^~3afp$smd z?Yg{ZP8&1asolqBjvP!%+`d&$bf0|Rbl20z(??nOaL=!;B;|`1s};SK9)j4OAPC@$ zFZjEYlxSpymmKZ=vrc*PF;56ut2p`(KhTU8Kb3SY#Y{}TX>VZ=h0@Gnk}`ojW~ zLpjDCXE_H-xr_pc>UF3$`}a|PM^HuTvHM9lbw^73!n1uCE(!ln|hVJjbbXVy#a)Hr)_2 zEW;h~bqpYCiql&g?ucQ|o4hBLtUfFbEA$gFn#) zCYE@bzFY{4g8*j#LC$fnWY&QUy5Z2^GL6z3LKKSV=iN?}w<4Uq$$BZWUZ~HTr4=D>wY$~s ztxh5~3-eYUU!>7E8yyXOE=GGo*hGw}MP`MlVNlnR(TBA-0uGgZq~>Z&828qd3C$#e zVRf_*mtnwl-yzvfjbzrZytT}4)Q0BoTYT|Z4avO_apTLsC$4%QKMLL9xG?KgIJ&6B zzCaIQ_Px{NduY_Fr&v8R8He@@+Au~KvHD8ag|g{+3JDN5brTqJYSvdLjy_Xuy>M?^ z$dvgO8`eKsw$g0?N6;3pk5Km5&6D~#xC_^eozxd=Ug7WUzEqJW=B3igKRC6%w54+2 zW9?q+GQW>|>TRQn`P{+?#29H*3P< zyVR96jKFb^)7R=F`gk+#d6}hfzS1?n^Gho(SES9hbRw7Du&?*@$-Qa&tFoj;W|f2* z26T_z$GyDs26|0&{xz-Jy6eA0CBeagyrYWx_WK(H8A*=o&i|m+YWWH;4cC51zIEbX zI`C$)_{n7-iE8h8BubWoO=HVHF#wIC@%cnwd!Svr@ISJi!cE_9#pGSTD+;^2a|@pk zou)s#47B2gLXIB|io=xJslKzTB=^0b>aSrP2ScEfMN%-#o3EG;Ez82SHhGC5%lm0c zW-UvayR(u*Qs|jz<~DsVAK>GogB|>aZ@qrXDQH@Y$+eQ^7Q0Q)uTgBys~b)whh!(F zol*NopknB{QYjUcH7^IBCto?!LimRMHr^D(&*{K~L49kb$v-IsBL>N(mYu z72g(t68|`iqv|nhc|{&}5gK+!R%q4lm-3by0w~w{IKc+{n)n@8^fK5!z^04WaM@a>=s^3=GiW&R& z&WdpPapK|u9^68Z?C*tP%}_r1=Zrs{!raA%!nbdgD0T=TcM+kPK&~{9co~5Nyc*`~caGK6tP^g$23Pp)VEGH~c@ct2-5Q z5HiI=A*Q^ldywt)M2-u_;}+_DitD7^HBdQsdDz;E+JyB%dfBdMF^{!Aju18tuSArASlv?T!TO`NL{(}9pVRw z;SX+XLg7bUu+~_qg5A6O-txrNZ8~{1x}|y_mgt<$lYhJHw+nJO_Udq*&fLA%KP1`? zl$Sw&bbz|XIWiQ&A{Yz4M4ut;kU!CpcTE*Zo0Qh!XgCJRL?LJ&1D4JK_BIxwE|F@{ZGgqkn9d@N=3zVqR!O>1bg+BI7rblaC&?` zOo@XKS{KsWfIr?qrn$2Nxs3Sv=jRGc>lp;;w|#&zBxCjoS9WiwlJnq&lX*GdLn3KX zPOO`bl{m%MH4kau7zhsA%l7Ir(o&$QX(8FbZ-?)1&(H3=^!oA2)E9kPtp)B_ zB7%LB?h6@%ebgb?;lCVlasU7<6AYgG&%Fk*)bnk~F=FLyzv@DjR?Tu2RD)IO$Qbws zJs>)be1faG3Q!*6wZgK$!=y|Ao@oORc+lNnUKw9;RIMXhEw41x9xFO4yV;w5cA^Mb z;J#?oBLE(DNeP`QiTCBTJS^wrkJ)9TG~_Cl3qm#TSXoOl7ZRl7&a2hkgu5E%n`&|G zrrdgP@FF>{KmTUPSab5l(|Iupc1B>;@^sgf%q?a}5-7C$W4kQ(ZZH?Y@p`n#yzoT|f4dPlvQ;Xfc+#OhiFgrAdpiIU}U% z^KxVioZ4AWbZaov8@@SgQoawoiM%X)ss09nLTtO}=ldI28C3ACP8-O&}wktbVZG+wBpvZ}Z z;)BF7(NK=7S0EHe%$Z&qy*Y6<6?fj3SaYeAH6{+EcE+@thYH#Hv+85}k1xwreV!|e9jQ~cZBC3uw){8*ONF;X0)|a`v?(b5 zYY^T!(*<~!x{UoXczKW$!by+z<3A{Wl=W4|IvCa$?F_`ZQ%Q1IcEUcUj}_VKuo*-R z>~```dI(Bfm23%G+FuaJiFQ4E{75h+WQbW+% zj@}Y=V{};b?@_nu1$%`T%hh`T>sh>GY&C+e0vEC*1cpJXmwTtk@ImHKDng19Qhq*Z zj%WeG3fjhb(`ns%n*r8njCOqcZ>y#Oh*^9Wcc{ch2&@g35%inAy7de>>f$1i*jY$R zyN^1FGWeCa{%B3R#yA$)jU83G>X7Ng8^h6+^Y{b0eo1#zW=clM_2*+#6myRGJ_#K^ zMZOvTM>!@H8)q#%!79P#vLo|=P;jv8!OY0?0W>okS9vZY8)x}mrkHd9K5WfC_PbFi zk3G5T_=6=c7m#S#85Tr+n4J7K*ca5_CpW|Lcsf|Y-gkVvkb`<8EOC2gwpV3g*DkI} zJ(Bvlh(u@Ze?1}NA!^C|GIkYDsy<-JVmb;)TC9?r4Ht6_ebm z^~vuPJe}9w_rG4zt-jVG5LUPN>_GV}0D)3h~{axX`SFey(?zV z84f0f8vg`(5B>YgWj7A*sfO#(?o-tI7Z0eyx%r~lEf#YJh&yH~REiGYLYJn!{}6sL z4yN}mY!3uOp~_$VF7){BN>SvOM4fsavFmrC`MzGZe?1e1pmdE*A*A3ajgrf%iLeom z0z^|cM(bbb{oQJ3)FH4DnpQh_OU<{1%FOL2aY-9uCYc8_lz-9IeI#bYBy092!wtOh zSaTWqY33$gFOGcq*W*g`*+C)`PY-(y2Mu#flzHmA>= zf*VwHhkOl+-1Xa~$79jvUl+X)wBw^v8v{S?v8-i))?;iu+mKH&)81qFzoFNH80@2} zwd1E(CEx}F)0)$@&_cyq(OS&^ z>Fvz_sm$L$E?egijxAX_or9udER|#}Th{DrQkKe^v5kn7Bk3f`(o}TD5@V>u*osP& zvQtKA8KIeKWNlOkeP4HHKJ#7vfNwu|^k_NfzTfxzeqXQad2Lvubv}{GapL%bsgDIw zKDa_Tt$lZLIJZc}Esp;xTz#J(;_sQaXh7mYEV8OeTN=k4Em~lW&Q<7~ze~m!d|r55 zUx6zWmQ~+p^X(SBMR1Fs+*f+|@Zc#Osvv1LZN?&BC^%rwS)*fwJ~tnfYaim3AL}$Y zI%iMh2ZTdN~(dH6A zXQOH-;xZI2WMow&;@G-@eA-@sQvWNIHLsl5Xr^e4^mTem+S)9U9gFs0!&jc1g#>H` zl zgFSB@k-#sFn(prC#)K7wc9z_X0&>BbdV;bWiz~U#u&4(0DWwpf))>eBbP_W$$JR{? zbSoGo^12>6n=t;M^O$h!=_di4J#LOzQ~2zMhzK7#G_FZ|FZA&Gz4brz{IJ%OT?90f zsbr^yO#Dp!BOfeaycHo1HemN`JNl5R{=>j6QAqs;f~c|eKoLCg&R8CIQ#B@Kk)AB+Lo6z3br?uJC|m9-ac%z@bSG8 z=CfoW-qD6W51N(XbJISOVlkuVekj8H1-rGu+OGmiUc7^a7odcWB6lUg$LjruRCZ%A z8lDNrs|=KTT?DE!dDsW%0I6luvy;E?DO5G{km~Kfv-hoPMB@DJhdd5ch&bbpT{)K} zzdW3bzx1ru~r-t^17;3H*%8;`Q7keHx)Jzt5S~fa^t?myb>&P;7 zU@v}sQ>;B=zV_PzXGkNGr~ySCGZyRuE{pW;o7>O}0L18otUJ%`D69kwOD;*(Nh zsnLQe<34}imTPdkKcn*}J-3Hs545f%<~+za>;@i{_|Dz_y%pf6Ji>5&!IZts7{ip+ z&E)1)K~AXwrzA&tx2jmw8S}10GDP&`c`WU<^!ro8vvDa|k8k|F!FIY<{aVaXbbV#m z+Fv_Sy^SR)F{18yD}7XPunm@~t%Xqpa)1)zVi@#P`FdFbLJh}>x zFra<4z)gp*75iJ#Xd}!8AyXeh(L#uv=PTz+);28Sa}d{_#m0E*A<$RWNN?$RCp*f^ z_o+n0XiO@Jx0$j$IsZwukVdW9SVA`{6YPCt?^NmEL$#S;pzB0`H7%todTQT1+Bpyr zK1M!`pW+jUqKT3LKzxH$4w=J34;Nf#5^N~8loOPclB&R$w|-3#XSEFLbNheZ9a;*e zxKuKp@y$3S@XDLPq!T{<3BnJf*`|_`-}aq2-b3!aA|Roh|8YYAD|t9%FDt*X?G(*2 zN3(Qwxa`LL7uCJ`bFvVa8#9=rikZ65ylE6lli@Xp$>M#|W+02-VOjwiVL*O*|i-m333lqj0T}JuU(_Nz1*9M=PJ&@^{DjO${#k^6gtTPu+~& zVgANbTZ?0&|3hAWdrO(VxaCwIqsZAPA;15@#995LRXmF|dGY%V;y4>8Oa&X~<_-O~ zORH8K*NIe)Xpy~Xzs={4WJgzr5w&4={P79@SMELgIP&GC6-;xPO3xAV9Bq2+f5Sh{ zZM)*w{WVAVi=aE~GWX%X0%Axo>~*UDc$v@zmb4o(A)m<5M%$a(Z52*KnJ6;fR&f5i zC0+#FW|!lQA#3a^Du1((hp4a|*7=U}(k4Nb zn3C3o9u~ug+j`h^pZt;?_f2wH&?rRe#pu;1etD2QZ^f^VYi2^Rn2W^-TXTGWIs}F2 z{g%r=nc~BL3} zN)eU~oGQ!y%MIiW;IP~FE(gU#RyczX!X%369Uj+sQG>xu2=Si5h*nzX=bX9%#CRPm zyp00UsX4o2ZM==?=c>Yyun5iruO(7gUi@6(v%G=g;NKk-&aFhvNC3-t86n~K&=QQz zC218^-yJUbcf`~m3E+NyAsI!AolgiT=la#l?-lCv9_-jWrTEV_nZ^zz@6a3iOM}~w zT?YrpI&XWtB9vWECr~qC@evE>)A@)2)4joN*WKv^w8FvIhWb@yljKzyMV47txui9` z6Lzt+9Y$Nusjn$S^hw4?qqz;Ws6Qi^+>22~ccVcF3Bgbo|E0DMXeIX3cUD3~>2(xL7zPa48JY5k<|A;ua;NR|6#57KNLcQ#6GUDI zgb5=L3hvz^?35CEU|j$^X@%h;dShj(xT$5Mm5CyF8PeM8Yl!@U=OX?2NG(WZvsiE$ zD-=5a`bS74KsoZP&Ty(oCSVW{Uc};DcvQ2&Wx6d9FbiuRsW35ZEVa#4JY9U$=oYqw z`Pv1?w{NoHRf!y<#xoMDMh=o9iLYR7Iqsf)W?1dO4nwVFhd{= z6rO@_3u>H z1C%-UJsY|ins*3=XG7@igDzgFx-SA)^^Ih&yRkLJ4n(Gb%#nn`bFN4?AS4?Wa_8}6 zcS8x@B_np@ZW#NDm-SKBKYay9ZN*b2dOsS~P{W4lyt!_XiLK;ANskTnj(jLesRCyk zV{7!Ip;hU9!CHViC#~;aDtlKT_$w(DAggsK9aWV!g&Hh@pDnY@_{f)VEhu)9gGYeA7Z=gzxqid|hr{@jqjl>?i?i=>2Ss2;YGo>E}B+ zN2e~#$0S8;?p(sQaPAgqVRe}+Bnp4P!Hx?;Vx+leEGRJn`4vJN-6;%~94#ssn2jNW z%%u%*8thxvadX=(g3K)RdC=~2J@HAcL2JgPI5;HEG~Q5U=X*!zE7RhT<7atUt87^& zZ5eJz&G&r9WKl^?QZ9^nju=5E!_?gk*U)0Q!5GHu{|Z9gsTe^h zJ;$pIw;R`Ox651+1Nbug1Gy0%FEY?YhN;b@wZq?#U^pu)_a-*bccqU4Qk<^<9X42R zo9>(3DzJP;%fDOB=ILf}of!!wqQ#O~EbTr`nBiQKpFv~g5sF;q!&mLR4mdK%2MVddc1f z%4v6$C4WoQ2uo`v-Pk5Dx&zIc&eCYVlSO&@vS<^6qFF$hR@eX^W8o!qZ%-L^IDbx} z43h)goyx(pNGq@Cv)Xw>E#BY%T$`WFFC;3+;v*G66bH}pyf*fPJ+OZ7^;kVg;E_oY zFE@X@f-d*3es`?a)EW*Zkeelg>>Iu=$4M%pQt=`YzE);x3tF-bo9RCZmU?=#;xd9B znzH z0P+r7rpF-Zu9%x7^vZ=h+?yrdDd@|SkvgneG^I2DuJBy-GGD1P#j0cl^Y@rNs~Z+6 z4&r@Tnus=dkrI+wBD^lsH={;g(LIo&v+nLM@J5Y7`A)oR66B_Z*;A%A+EPw8vBTFn z)}=E}#bFv`nJC@NR@rtg(46yrULwL8k-rDjh>$LcD*2t-8RUjo$R9}R60GMl=Gy2R z6TG@1dkLme#=xkLM`)B^SZjIWh){fi@*;Aq77nYi1h)u<_8ABL=uqhMCim^$6BZQt zEv%GAwZ1>slzTk7mOL}&Tje!Knn~VvroivfxYTarm-!p`FvI#o#gcVn`_isr&I_eD ztd0|NS-va$kR5c^#6=c=G7}1LjrWWr$8Uc@u&FM`$LXzLPimEJ39t89f{t0`jGy74 zBQJuE?aX;SmbqChgI@8zHK=^3D3o(IagT5e{7&;6+ROo}6X z?m;HbozZOh)c`WF)=OSy4L`x>5HRIzheIaV4qDbmCi38Kb*=kWq=l|WWKt!tQk`>& z#}P4{kqS)PA4jmKd#0Kis|Z5Kb><>R4h|;6e1=>|Og7bG?$$GFefGXZ;l-*nJMIe< z3sH20>KbUuY#1beyBaR?+!_4qdlr#$oDcucPo_QeS=b4VLH_sWAi;E0FEGvi{S(e_ a7T1+*SZTbqrazerzbtp!n%*$>PW&&RKTB}{ literal 0 HcmV?d00001 diff --git a/docs/design/phi/images/tensor-design.png b/docs/design/phi/images/tensor-design.png new file mode 100644 index 0000000000000000000000000000000000000000..3540171f0f5e54566a6808ac9e76886195dc6c12 GIT binary patch literal 134604 zcma&O$Ik8AvL3c&2!bI%FJQQP0I6{gIGNdG&g5hcoeQSPoHO6RJvHoiP)m>EzWeU@ zEj+mQQD{iO#hPn|!8B4;ef3q<|9^Z9@!$R1|Lx!Y@y8$kPL??Pk3atP|MSNm|4aS9 z{u|KpAL0M^KY@>bwb~-{$AA96|M&mzk3as;|7+vJt-pSCSN!oeg#GK=-(aX3*X?fz z_cs_$rs?l_os>(`uYVtk4K#t{{g~Iky2$@J2ZmtypCIg?K9CE+e?w^OeM5=A;{+D} zb$(J7!#{3_{T%`~`1g8})x$Qf)!z^X{FaBU=)fmv2JfZ=-a_#I1O4*{`)A@mfUop4 zbwyC5Y7P1zU2@^{-~?Z9#Z2%m0{8a8dn_30&I#4nDylim*tjgH^8D`PU7G z1PtKMh1kX=tNv*#hIL&2x)tGnQZQSL-V0y!^V>W9YgRMvN(&)^brbui1xJzHIb#hL zqK5ta*uA(94XR4#)Y7*8D(*bzkBreJCy*Gzp-;(wo38|>h6!~Ytcaj5gle24{RQ8P zfXO(?&5Z!TzvPDAuChK^dB=lP#-AZkb{)(GBm1bbv)D{Bl9TY|>6Y!8su5oBqgA03 zrt|0px+Vxwr)NS#eb^wE@1^-fY%}~yiqu!?mO>c1f8cIxxpmo3Q~I6iQ^s}#(B|d} zwt=1y?TUh*Us}|&=pF4fk?tkR_DbjuAyVIEfIZ|^5Snp9f-+R_R0;e)jQF$FgV}{a zcncyQ^l;+hSo1I$`IGQ!mig;LhBpTGrjUUDa$PkF##T|)h|PRnwuC8fJMFsL6Md}% zwT?pf@SfquZ`63z)0^c7bO`#=&!nQC9x}2E^%^_P~4U!GN zz}SSYy_~+u#fMLVdRW*74cU%hE4eDrhUM?uiKg$T7wqgD%Ft$@Qw|S?WuxJu$?0=Z z{7|hh-n(ast{?1)rsv#$JD+*{x5%u!_>DR#Yqlwe!3YFRFuc<=ZtU3G68uVYW)ZpT z7}{)yw86{dK3?*p7-bm@Clfi3);pOl+gymJ%RxAOdJiy)1moS!aL41=SL3o`6yGa} zKM~XD@-*sDZ2J+e2cE!#(=n>$wNKZ8=QE`M73ZymMwJdgk{NXTGHb(b=J|267?O^W ztgABDWm$BMpbkyBlFhCO-nE&e3$gTSOX;zKeh}fnPOgJ?zd8HpdTioUxm-tWJC9sv ztrj1_^hl|?~Lp`r(gaCLyspaH1V9A*#pbR-`dqyflH}sn?7lO<-Pi$ z9vH2_jJ8=H)`+oJd|SC{aTw@0awK+Mac@l!sWRy`Tk66y`LXF%>+6pIcvR~Qcr$YNzB4L%aaog8ZWOr;zx$?VJ ztng&fX#AUu6x1fd%>#|dIH)F`&p^v@lrF{K*;U)oE9mqgOm?EryD{!aUswExsMj!> zy4u_)1PkvqJW?lLy^FKPu2+fdN&Wunm%(T>xTJ@M^XcWp?XRmXh&juiIX*JE?W?ZX zO9EZ(P2U=92yR4_efs-GXW{T8o=cWi&Ab)zq@hFA7QtQzqZju;Jnf$+dAlNfQ}Dx&Pz5dTr=s+_sw>!r#-Vndf5 zq8HJEa>z#$99{AI+$XGXWq+x^tP?RW=c1l1xXJKTL06k}iQF!0cb7u6-I`R$LG3G- zTgI<*W_Ct9pq0a}*()09syZLc6Xv`2m~NjFep?5ce!3h@emH5rnzBGPsu)59%pbY_ z8eo=$Yyj6XY^2}O4c!+uz5-3*Q(ii%Lm90yJVw*xM&|UkJeXzgQJX>beohiV&_WJ% z87V)!_H&{o5xb#}2${N_P!#KPao@+p;2DsB0>`QbznTS;9Nm?FusO$slo%nisT0oB zv4nhvPcvF375&i8nBiiUh?)Mr_&`AFUtaJO3A#pG$4_ND$$1XzTypg_QOPHk>=yNp zBjjSBFc`{DVOzJkO8}z6)G2SlyUb~--VSwyM6+~GpWA^m$O=CZ9*djzq4}y_Wb0#K zIFuQ?DEXcsVL_*Td#Vvimb+=3LUt;648rTtcJ@87te4K-}9$bQKs7C<@Ai zliLJ}9(p{>ud4)3c735a+b%y=ozTZbHIc-8+vu8yW%W}CHs~j~dOJnXLQ+L2Sx5=sj58D71ku%SSZSuWnhGH1UC zyatFInoD{esbY0$dXC|zRUs|#y%Q65k&v)YszB&g!-yK(i=JoD)_BIDgAr$ghkFbf z(<)gX9I7jA^CMW($$G`DI~Gf=?oP_&*hH~Y9IjvIflDT}j_d-0Kf`Nyxbd=JCN4hW zkvo?+dNsl9+{WKq`2B$6MJP~Vz>Un@+YT)W?8t0_L4+*o1i>#wL>3QK9{m>5PsXAT z{eWvcUuaB1;}TPGJ}nX|ShYUsAs6G(FLUwYJrajoo;-WM6^rkO4)c27m=({kG%aL2 z7Ah%?yp9VXBDNCzh#-tauV?HuRM0>D1a40{-|6=Nh;jV~p|b^UfitSqvx84Jer}&#&u*)GUYt65qg9f)zxi zV%|^GB!Yw_T0%BrQqn*aLh&NIu}OeF6!w8bGMO*R%QAk)Vnw)+Yiw!igQGFt)T9KJAjC}}m{1DK{&^IF&042xxZqiHWagG|Ci@$zDr&{vc zpy!h!Y=~|!fZm93{_uGmPbg{>F82Mg)@Nw+jh)w*x4*u+Qu5C3XuHfGSTC!Ae zj7n?>NUPfGE(K?@K3CnY*$X<5nRDGd)Ez7OsE-TdHe2pNeq%*ZCGt9JW%q$I;xy`n zQyz<)V$wqnDBgbG^9!6s@PPl~FaA%Q$L*gJrfcHPpmM@9l65u(AZ%b!q5tMQ=EYi< z!#^Fs^?seuKRrimMPI-1FfaYfYncS=!%>_)lHL6ZA&BDWDI6qhYi+l7ARm=rdEQZB zCuI%Eor|p*i0@ecH>NquGwwb#y)o$OF zx40xbKzQg^8UDIXYx|m*7AjwlO#}UERj0Y(t ziR#+6OUwwH-KGFM%|y+dk`D_@%_~`ljH1)8CiM5&u$Z|)G(2J)*)uq}a__n{qIAQV zgwfG*L-2fyR)VT*BEwvn3!lO-t22a=&_M8>C@zt8JI}C*Wm1X_aJK(hV z5f_6$nY_Y8MYhLMRC>{8U#22EU3m_vlGH zoUO%(13N$nV_@oqFr&ZL`UEoa?V0Z9`#o3`+@-TN+DL)1TT73=1P>9T2+ro!o-CGA zEhx#oD>#^-z6wW)=e8dmTDI+yK~F0qy{QWDla@yi!uzXQG9&IX<6o~Q`ddLy8I@>u zax2y2GsTcPaJjc!Nf%ry&R77J$nj8qu0e9Fv2ED>AMm-mB+WE9s{d(!2b%ARP3!O;=X4MV5d>KbIp3)Q< zt`d!^rm9nL@IV;IpuwTNi$i`Z=IM;A6|JP=U=(2WQzn3D@K_?-#0?;1TpMV^@hP>i z`-V4RK(2YT|0j#Z&B&V0agW)&Kjby-51rT~R`7XX<<%bx;89XYMS&XV@{XGFX8jN_ z9JmYtdCNLZ6c1Ay0{!-#Gd9ruYnF*o68p6sNv8V{*P-(P(jGi|24t*21lV+pbz-I* zEZ-Hts{{75&#RfSMh0e$6^bh};4#jcT+yK@{LR2miIjlk^%)pPg{5BWdMEI9sv&1W&Z)} zS$evFT+|jiq=GfhmJK^w^}&yVuk4X$Hwp;Qx&t|5^2ks|AVU(FeF9?|H$@J52x{#K z4s^>O)?9Sm7`dg(q&i*T6l&OcF6Cy-OA zZCmyf+grHi<-H$95%VSQR8hzxo1(Kx0ks9}0%NqFXK{CD;RDA$zEQ}Vj_)(833Y*j z5FKiMe^X@S4sg%f07E;Cc?6?0gt8;7wljBMOQ%D*aRZIvC*U;z=b2Pq8*US2qxAYBNf^WGe%r9H$2{+Xsg7yG zUB@v>^?SXI#9@!300E{%J?&6yy z!^6vQ_tBkt*QY#;Ix(wnHjMVJyX64d5>Dt!twZRg&6I#To`e;R8x)vX;S^uX zbUusqJqL2~2x2{O=A7u1afU$Db(9FSCrkpYNAji|7!`J%!vL3q z@h*DZf>X;)mqFOf3AdI+IkxDq7-xrq#Cy!o{X`Y2uVbHE*M1ELptas0Y5@VJoF3rJ z4)Wl(kKfWbL&hA%H`h!x2=@zfhD}(c90L7Yf3b>b08g^+YsBTVx#rPllC}2YWi;t4OJ__gt^!ZKw2Q5A4LuP2Y&AZZn zfNRkY2($*qq8K$MYIhb=6xcNkuUMCS)E>+#x&CQfITW1>E{beP437B_Mz|3YGF*6# zKwds_{ejr!aq2Y0-mgjEjkE*ARp|bzx%V9A!It33^zv0?vJuQHhKPs`42Tc>S^fsTI3E|P_IB-^tE=7n=~HT-R}Y~qd&l-KNPpSK z$(~m;fL8qgp$GgJP&3bX1o3FmCQx+n%LV$T_3L$>>KdVN+NKBs7IXAj^O(p`H>HOp z4*Zdzh@+IFRKU&y2t5wf6VDK&UE!n=RDCpy(EpO>_RBGFLeD>JLkU1dzO#+L{X!)8 zi^KtuwVFhAIgYei4)7*ZGZ7KPY_9!UJ$y)hrC9a;*hix>{ppg8U;OFAj0jS_@bphF z=5(~jb_v3N)$tqkVt53mAnk5FG`aQsIuAxY$`uPPR`j2;?Ql0NQ(Tb@a+J8fjA1vd zH<6X3Ljv)x!^l6lhI)BQT*5i*%1N=QhAI`@Q#h*M+UvudP|wEjJ=nGe6kQWaR;G_$ zKa(T zwUCNK%Lby4S#jLjdp1&ocOm~~=$^?oXmGHi|MF>&sE%I!;w{Ik>F&BQ zQL@!w9L3|Zt~IRB77td(X%&AG>FMXrR(d?HSWpF_yG2dRHJiph3J5JYq810w%ee zx(0<)i|=1=$ASM#z&>~X5TiGCnJc#xv}KD=8Faq`1@!j+8_f6;gvg#Ut8my_)&Nx3rd z6!CO<)k(P12VO7tg{55l=DoYAs<{`iL9djqvLGixoe80}FXX&U7+uO<{Q+JHv zG1C^W4BX@E*%{oDG7D`OG0QPN7s%tKB_U4^0|HH^FwQ&9<4yH5BViK^YifD3pM_Xe5C>0!&19FZus~vngMr9b1h`dN2g{#<$Tj2< z;7VDS*KO9BtMz;*mxBz%Hlu`V5I7VDU;@EZrbZ>ha_h>c24K3T=(_6}dn)k3{<#aw zT%5X|w1_xhz5o8W3FU7J*@1{%T>ArBFwadWh>uEdk^cBPU&fv@rjo zN~>Rk#&pEE&RD>hULQ9x8%s$H&(%NBW2Ml;sLM1nRX~G z7H4u|%jsCy8V!-o4{U6#&h@)|85R2>ZVIqHNWOiakolq4R^WyJTC@##qLPDwSqJ`#QEO8Fn>zb3K)F<&(s za!po#Jh^<-)SV0z{*Lat4GWpyMU-A94)fc=za+%PP8kudRLD@_xI-CdXCRvLsObSL zq+l%lR5nz)!_?Rni7Z<=c9fPoVA1RA$nm^f=oPDb-uGg`k{i3OBrKXKM{724JBbwf zB)@F$fh9`tP1gKXhJeGg-ybLmJ5*{peC75x6^xlaBbZNNkB@vHM2CEYarHk(TJ=NN ztc}938_>{<*tG33-1E>9l@P#HG_57U#g0<^a#cTsrcE z*d}B36nKlYSIK}?3^*Sjubrq2Z290rcU`ZcS! z2+)mT2(zx&nJ?$q4FRZ?nIOyHcSqW-eAP7`5;;l>tcLIDe)n5Tf95Wk$%oeJ*U{G) z8cuGhcq}8CMOx7ic`N^NgZb0!2RtD{nsx4ZTKq+~DAyONROiD91QS`AfDj2<8rNzq_&{ znurX%GcQ=Sz0aR#)iYW$64vs_wHP1IZXUWgq4mY>j=Gq%r(Q&rzhNRfP~`IKQWUKi z1{%ZXMeE8zkbLiMJCfH(DOgQE_I<`N)eBGm)zygl^H_SxH0Ms|2gUJGSWV~8yhaez zo5`5o0QEijfI&Fyj%*3(Yd^FEACO&xkS@Pek1#)lX4u3~MvA!JR0q6Trik7RR=`In z-TZ3}(MS6Q(BpEB$j=*`R@eF0YqZn%X%#PwG9M~E{9Kry3{ebNIPM%Me8YZ;ZoS>lVBU8OE};B9*< ztNyjMwQ+L0(pm5Rbio~(KrlMA)D=L#fX&){p{KpI;P`b`nE;*>oKo$u zC;*OO5L7XJg#ldpUcF#@@}ej3`W&ELVDDW6GA)et2fQO+*|&ch8^w85^98R21NB+W-2_qF#`w}hf4G7-svP{ z9cv=EY>vQhu9Z=8L`F(qRM_NV#bg$xDjpSBl_FL$xTLO?v*OAvK zuuh9Uh|Z?S7X-SbZ9%3GqF@|6z!Edzp=f}$86qptQ-;3EQz-<6zQ90!Q~n?DLfG*6C4ljGDzZt6Ss1N&fDh2G-Y!|NA0 zX>s3rHQ;N`;0q~NB6tb_C5(Lwh3`y`)ys*N>;jL;ZNbY+>{Q0>?H_n@`Hzo>iJbW5 znkn@bItZ>>4`9oH90?YT{lEtg1?OHoX(qo2>9I8RE)>m|}T>N3dYkZK9G)t5vU&9ke!QA1dCvh6 z5SO}U-5V!&FEOy^$ews94i4xV0T3SGE(}Th0+Ui+r}@z^Y!rSNPo`$HSweyqBpSqm z*KnWwK0A;Y=V&4ut_eqXYzTSrr#ecqCF@m}Ha2}B9@7N^|1RPUxZ^=}?Eyc*uX<~z z@(NDqZv8?Gzrbo~=upYg7Ig%fsZw!EfVEa_3a@eY5L6`H{0I2UGcxMPy!p~atM_uW zt91o4UT0zY&Z;g8PM_O)&3f$^jGHP0D9S{Fwt){bjmV?hhygE)uGMCg%#4MnK!{V# zf@Avq>Qx*v*Nl#yJx@3PtpVpY_mO9${M~BP44-W4;aOS3VX?nFzXGo4x8m9~o}7r6 zvp(2~Qs`V_SS9z_o;2$S!hHeS~USURCuTONu4$`x&qL{^#i{4YCdu>>Zo zu0R}t7|O#e%ibsRyS)Ay>Q4{W0nb# z#OSAwoFuU*n1!9}F496bCw8rr$20o3b>)-rUn@fiJ&_2bi6qVVg$>+Un0m&LU^A_x zmtUOq5wPP0Z^)g~>TPvmrhUv9kh%O}<}_yYJ5wL3sh6fbD`d`jDIY_7#9LVEMitlJ zN&hISIJ}1OGC@wAj2pCeRm`m5{I=XlrNy4iB^lF}sn?&wX4fg{B-$w4?=@2(?lVr=^^Sqm~8<^!w;wJvd&1u&F zJFhQiG;i960MGO|WVIw%HUkY>;U};q99*N>z&96DWxCXI{zB&#VqBl24E!h^ytuEJ z3p==)J}RdC!HO~W!C6X74%8Il@|BaJQyH_b^;{Bh0(^?_h#-Os^5Zd(HZ}a}z#}@X zJ2ihe^y{N!d1fsep8qcVnc>JE+;5UTb7 zTsQo@y;h0eMxytf_3w$T3PQFr6J!A@!f6tM>Vz1$?(!AB!_G zSlVf+$LA8E_YhP2hni8`79Q+c^wlFdQ-x*|!D%}avPoX9q&=Fza6C^Pzu~cBuF5s;pl|ZzwKy*=4 z5K)9$%VBug+U*pxi|VZ0HAr3fq*d%g#|Ht{jw#sGkB8&qnJb0>q7W#OvuL_rL%PdN zm^`>B?n2`YNI3JOnF7iN1zmp*QWJi7rCZvQfYttrFLDa{HrJZotr!-|?i8Wdh!-8O z!Z+pWD|}`!a28<^kxh7)h7%eQi`jU@qQD+y3c#aGn9!0Ov7A)aS&t-Dy2%kH4Ff0f zD7Yr)n4tmVVFr@>XsYug06_d5UY>+NcQe3uJbdtJu^mc*YM+V#$(ih}6XJL0x?{YS z>FmaSJECctHBtxB#9`C9IN`Kdo~3SW^oDHOCc-!H8}DJUfPe2xc$$Zx(7 zfU2^YkpCV}0`1AmT@2q80u?!j8*ZwF1+a`OMax1?p~frgFÐt*LY$5Yne#bb~mZ z&AI2oYHO0?!f#U5;`ty^VM&Hv=0Rj+w+`Z}N2s%)>;}Ma;lb>FIv#^xF3*g8Z8I=_ z1ZnQ1Y%-@A!;fJaa7&G?6&OGdF;{(<_=|_XfKm?=Z4=JQ)0TO^UbXVBecI249B3_`8Knzc?(580p`f_aB4$XtcaGYRg2=De{)pz7=t zCn8@-RjZwYh>48xU!rZ%0Le@Or6F`zSP}?+_m@$0`#2oc5|nN-Vm7y;7_MOq4RR1y z00DvKW^HWt)LZjZM~qew*X@y*gR_=BV3KN;)C#$brOMUicJ z47Z-_Mc;nMOg+)up>l$si(P2t6Hc~LDkv;zSMxBl%}6|eZ9?uwS7{uU8>+jLki`yaetifFiOr>+acY9CSk4c2bAwh2e@V z`-`!H+LZzasldLXOT5lF;2C1a90O->0?~z4enmj>Br*Svp;M>q^G+UNkU*=6l+!+2 zdxM6|P2leS(?tg7Y2lZkmS;LEbOk?sE9XueXk)UMGI%zBJNa6wY1~JkC>S=a?(nsl zIY=n_pasZW{t0aN>r%bLw0BMz*jj{QvDzUgeTxkT1&9W3ju}4sJ)4| z(khL$1idO zzbA}O9gKN=@}xxLT`vuPwi4(hnn6wSqF=PB0ER=?Myb4|vx2#{)KKjwq!N?yGNXUj z^8v6H6k<{wQr=)Y3tlwl%xeNlc)964%UF6Vj20AGnsyJaCutgBhwqw2?@7hYy&H1o z7IAwg-2w3?+NBoh0}uFpRAr*Uu;~*hs1MkYjKtI190)j znMGWa9rl7R?*7l10p!0j18D|yq;P;di>7@@hugn;3{os`=hu9G*)S1zK!XIP^UN9% z7;pioAkhL-C6bCjb@oJ=Q_OT#5_Hom(TFkiMfSw#&JVBQH~?gAq!|~Q=hJ-60>r%f z=0|oyM#Yr{>wrvkDzJYTA6OS~KCgiazhp^4@hnk}pgz^EHwEO3qS@%y;z-yYc4taa zpNNQ?BVr!eKr-s}3*`o%KVMQWK^oj=+3^(znnwUd%S!jPK!M(azxpjXfk+5Wpv_=g zGvx)TKADIEO7Yil+=X7$_I@Zo^u8F?K$3<7PEm}2m^0Xf)K3DkAr$uOh5X+1H-?K% zE;|N5>1ZdL#~eu0E0P79O%G(if{MCqTV()Dh43JVcQ$_fZUd2h1?iQioNC$}O17Cp zFc3WYZIT{oh5SCnvf6j=e4Y3ynp%%9;Wu7d`G_h6A2^$F9ppNC`Gx@nF^64=5CE(b zaz0l93Rmgwmq|cqgb|SI0G6n9^zJ84!x9l>oT|E#w*#)9va&~SH8WFq@Ls>0UHHPV zbRVW&3fRSWS=uADx(?V^;eb{8Ix^3f2H>a~d`GN0j zktW=cCIx+5TMc2_vF0BO(lIxnsojDWY!FXWUdcwmopoz0c{G|V*By&c`dFIl`Xiz_ z=bSSVXF(x`QhuV1^K_|V1K_+d?~M6JLBIa!giN%<-&2`~_sj@p^*Ct+K4@<$*-mPEkLgQ0#mh#s)@XJOR11# z|JtYXm!R)dqszNO(0gg{kTUWcQ!N6uIX`H46z;?6KXe00TD(l!^bd73Uv9i61J(eT zK--DU1x{DabCd3xt?jA{@;u){6^7n@O8Vkd1^HH7+jx~9{V3QcLD%MIL`fjgiwcs% zT#&ka08fbNDRj&Eu%}WLmc0^d&1H2a@YewhcA+(sP(bL}yHOMpZSqdC*76~`T4`$p zH16wn5gyRixn9OrCVl7#bdWD>BmfCM;NT)5A(%B9B!y;!$-muad-e+#ifnQn`gu^GZH%0M_ zTYWFkF%ree1;0Lh`3PqQK@HctTwJecuSZy%smeshmDJ1=T(i-uCGb^Ua*)JaZ5c0; zY^VP&$)Y@nbUy<|5!%{Y(KbA~n|?Y}s~qsu^LF{{w1AWQ19xDuaR(Ku^$SFmVTRW2 zJ?DA_i|TNIS$~a5n{35?;nLk{X}kw5NSgYGzQ}zSJy7K8BF>9@EL()lcjT&UfV!xO zB6{OvP^6M4*M8-x>d!Zz)js_Qy%gh?fvQ$YJz@}mT*)8!V*`l!BDzBWNtfVL>u#Ul zi^cx)aghV+rWgLKq~O`3PYU(5#e;*}3n*e) zG>$+@kd&@8t>M*ZZYBYGdEJGIt>2UCxg`_ech9dAz$ypgBQIGVI!UZT}ro(Xu=cmHyHBd3^58Q;{!23TMX31*oRcbT}?J0l!bnNQWx?l%h z5aJGC>=!~;Wb4mcK_*`-P5_-(dK89q-M1*B7a{CeKX@^U1=~Je)lzoVw5+-e3Yb@I zv0@7pxFXt-kS6|UlB%y?YnO-x3IhMP4A+;L7&=*sPFy39ETfhPc%lswH72%ceLE%m z)GD3(;{!%!ke2wjQ>&S_*#13UHIG6EUQ~n#G=qS=o8jA@!WoUj!k5O@0=q&kI1e_l zu;27UL{%XJyz&c%TuMx9KPL;BgzP9y)$HpCwOKFGFi z_+Wmc>m3{z@%>ZMY8UPV6aj&Lzc~|(J5p8gN^r@FWEEaH4T=YQ?3fxy`SJlMjJYo} ziK`COOYiswo7wrfu^o9v`vI)BNWsS`%`^V>`~#bgZ7}Lm5m|E+rVo%mMpTZ_+whW@ z0J{wswS9AhG8{hY5iX3RRiG#ve0k5rP(}5Nn;rW5>#*L>4`NgNZ7@$N!ya?ISaM8! z@8b1PQe-+SIhBtrMU3c;WeExvdrOGHbtrWUW(_|K3aaoN|9P&Mij0p1&tDq7cO505 z*(v%S`~jHWpk6mEwR64bx;wCLHV@)hlp$yyB#=F+Ff<50BhCDKwfu8@udjNJD$J!r zPthrv>e)P6tYzxmbR2-5f;a!Fk)AbQY8v{=DeN$KV@QA-NMEy1fjZZ~m2!RrF2JSL zNvFN2LMCLj4#LJ~Ej#zK5g39{4s(k6K|X`RyP%yXcF^C+*GdoRAFX=1*ynt`NUt56N!nuE(h!=x8MyG1oD}C zy6oLt=^Pnu1kxp>n~CPdX`u9C@;J8P^C)^Tid#A|)z>n6-9ZSddt;SyA|Ae<6Rfw- zfj?do7nG;eAOTJ>u?4~_m005D#l`9fF=%m~zJhB7(?eDgF$ujcy*il@UPkl_B`}Z38pQC;P+Zh+nb+Ka9tHp}O)wcPT?*nvXQB zZJnD(z`in_G6zEURP2L971TF+J%#^2WA6byJzm$1ZlNiN9R*hfyK<39GHI@qBs0m( zq}Q1Q5Hcx~NhW2I$s`knMNm;`!c{4P2>Oc1Qlu>EB8yTM8z>;KFCZd{6p{AA()a!5 z(R9zZ`4pi2?z62ppJBD+UoxDXgY-ezn43YRg~_?~y3X7l z*HI~m(Pkh$`i#(;H5<=cxODrBl2;Nyw;^xMslGW47#RR1``G$T-ZFKy1~f0ZB8TjN zRM;AC^l{#cn`IgH^Ch{{^EHJDtOO|gHOMDX2ljMs6ZR~Qd&$ugcGg)wxk0UuK{J0? zqdsDc;nZYDC8v{Ktwe=N+iV;kJ?_yUwFH-+_svF{k7EarGMt(~$w}T#qN1q!njo28 zM`ZFqMG}bz5V;B@$bUI|x+3ObgAwkzkq_}OnaobX#?$$uByW_3EiT;#7}wb*%q|xe znEWU=EG$@iM+FVOP@pdZKS5i}5&M=I$9HiWv+omoA~7FdwcZl?0hzSidP|XoR^Lw4 z5!2>D#Rv3A{2HPCNwsK8IzAEeCc*^2P|;)VcTuFo7TwadULIW?Wq{_W;K3%R4urSm zS^HdW#}$oXfT@A-j%}WJSTJ~asG3GB?JfOb7=tv>UK(FEj*d{9A*eBaZ)*#W>He14 z>^BC7EEY$ANNL-$bwFQWwxb7Q8K3HXG^pkgkPo|njAQb0tmSg+!N}#e-cWNm2OWiV zwc(iKVSCUp*)5RC5h7G;>dw@{QCFfpCDkP$j`oVdOCZn@6^M&YtX|J`;8F3;TI~+$ zHXy1*XuQ6W)UXG}1#A#)=PBkKzf7wD6)S0L7EANY(Dt?x>njx06*)@*DJWhx@_Id| zrGN>m3lfcs4KHISW1*n%tFt4eLkmQvh^foAC)RC_6)mq3T8B!*iDt&Za)51chu~yN z0Zn{=!b4=L!Fx~Xc2pbTtF2Pk-6m`z%INXRCE-%C5x{oQLxc*JPR4ScPR8pi0_T9l zI(>1PNt1ZL=@+1pv5=0=1R$^Y1{tKCCSBBifIYk?SAJRTI5?J)$}6TM@^+lshXg_e zP10+zo^8=fE&=#YlbYoL8t~J@WNVfZtVfLsP|zYm5zhrMn=P=`S==#86!fCpLRggK zsg|Wn2_Y|g7_Vo2n37KAtFfOAtR1ejo$XUhiRzwUM95(5$@wmwSyQq^4{H^Ox(1A4 zR&`5yi!*26UxrIfRGOUI?NKjW=IXwjl4zH2l|60r%k3C`azm(H>ZSq27W|KNnMJ`s zQ!O8ixu210NC_fkt_TV^eDOLuP`b0Vy7{;~O^|I2G~;uiqblPg(0Fjzbb}ky*)$&0 zSpgdN7PeiTb_b3q%{95F%-AAm9W=Tec)p|lyLF>^X>JsQvSoQ)NRtBMR|%qS7%5)| zQ3AW*e70_k51QVXhC8iv$M}~(lD)}AuIdBZciKnok#RC*;@nj7AGd3 z?S1E@hT=SAv$;O0SM7!$f<1N`0@MoY)tHbBl!Q>JD!Ra5OQQt$6jZflYPX3m8%OGF zVB&WBvf_CVBI85gyvyT2(+^cOQECuZ6QTS#$Mr0W#)Y$WWZcYy)fP6TWuuR4H8S3Q zw>&yDvM`CAy%>{4u2*}yg??)l@-@^m$0ZEj&v%P`@9}}KJ%9|-uJ*dvSSd}@>0?{T z5!>_ZA{*x(D^-K`GD8I{KkX^WVce_f8H-YCjf1LChHW(r9y+rJ`ig#8fsCQg3|a!j zxcArSG=#P7-~^T&uYq?2ul3wZu6p6-g!Kzzz)$1r&hsV=&*?zVTm%*@zHdf2H9xN{ z7O##pXbduAHGFDOWqBFb{}_JmjJ-w=Pg`b`Ebi7hqCBjx?&uHZ1 zL%S~%V8HWxazkfZ1{s|3zH)fjvSACsVy|}xbDKuJqFv61FgcYV8y8dqAWRn9cnM_# z-ZdQMk+04kLi{a0AfO1Oj;H($5!-gCQ>U$rG;^p z3F>8y*1|P2MN3gmY&Jj*DW8ssgs3O@yIi2l{ zMRnmd`LsK7&XbcvG=cbjUx^{vp3Y0~rLRS?>&?BB8gC(ZS1d zQCLe;$c+TX(>yO4N|C`g;u1z4D@Cd;(>XJkBN(F34)`(OvAHcQ0WY`W_d6vUM;r;* z?5=Q@+btu;Fn;(k6kr!dDv+UN$3!pbfG!FkV}$Y%f0!NeL^mXB9Ws#_YDGs#^TIAHH+tLvFG~nHKV{U0ghd4AiV2f!v6g!akm{`gAY?3W( z@hlHD;)|^ya1?|-L)OOb(SV@WQ9@Dexs5h7G-CKY+OL|$4%wn7nLgA2wk3dnk`-nX z9{4`@dq!n@B~0K_z*kaOxWs9dtT2_cRVI@=(Hu7O`9$=Jvh)wjP8&Cha~gmKRviYC z<}AkruwcY?T!y&5K$aBSz(bgGB%(&>TWBbqmSa_gPpW`(7Q&XF?xMaEchj}1f|jPM zSq2CL4pukK0c4{%vfqpNZlTVP?H(o$D0{L>S+5+2M;Qm^WrEX9!PWVi?ThqqoG)i_ z5tu^^WujH$P4!Cx{Zv)-oCr9(PTs~g>xPE_v?vT>zJd$JZX%p=mfVFSiil2_93Yb# z&RKnXl3+0KEi_CotXnJkHU#xD1fgxAqxtdX$&QGK&_!o}Lp`1oE}=kFg+s=(1$RUd z`f?h?RVI4ln46L-fHeDz6wUp949kisNeIcFk+TRoWb82+Y-nypc*HEl&R9)^yqoRp ziFqm3YY;UEhccc6niKAje>AfTCIUjrSIh&Uxa%6k(T-_e99vr@$Q^x2KtI&QPdwDq=*hjP6Opf`M}^}z31M_|q|NXOWu zj>cml>;!*XV9v$BFS#Pr2q>%es4VN;baQ8~0`YLmUtHbGXr=Kcu_+up_H3HYI*Kw? zCDZWAEWcRtINc(%J1A(;G7bdaYNbN#zoH_NJbLJStD{dYiLossB|rf#vKz< z3v|j3D<#3J+6Q)n=f)r#7Za0HI3s}`U2zhWe}K%(NFa5xC-jR9f}*q~WboN6Kr0!{ zW4EA2+ivZ4m~WSHT3i9?fu0%#>Zw#P@5e=EGX>nuHf|-X>9BstA?3mtOvq3B9aU%A zxrq`zzEr|=HMb52yGkI+6I-lii}=u~Ra_>6TSVlQMsg%zdFveOLjP~Dju>q1>@>opq%785Gv0?Pq;Zt(}#+GHnQpwJVyL&DRv za_6X^FuvCzhKfO=<0H+I2$cmAbdx}UHWjPz`+f>azJLeadJrgxThnj$OgpvK3EQEQ zvR!TPn4*AVyqpS)kq*lXwnCVDaSgxDx8izo=CG~MblYeyGlO(d7W1ClHVAfWroX7X`w35-KI?=gYiSVolMY8{V}7f;%F%x<)W07@OQ z3GErA1%-=BifG!b!)g?}<}EeHj*7qGMzs#U-ft_Ky2PX;AybN+u{N^@9my4dvqa_8 zEc9?c$K&HfN{7=lUQgq6{2WJ1k%|DYUMQSl3gVU0d9?CJABOl4NM6A}~nC=|dnUG8e#DL$0a|ZMsCpQa& zJ=UZGnnJa>9iYfm1b~uQfQ;)&M5F^`dFC6S4 z&Mt|vhfK$akr{MrCY3B0rmkYKrAN6anmZdl2{Cohh?XH78yUI0ra%#M1Obd|7EI*1 z0Jx5lrGa8=w;&g_J;-U(C_FN+igD|m9ob#kF+_AW=3R#48f+<;7&gDjjEc#?ib4x%yQ8kr_P~AWhwsloG(}e+x`HW&oY!JcUv|tnGsOkEH~76)a>WYM*EEinT?2f77pf_+b+AF}aI4O@#~Uv| zdzfyVi#v-hA~FJt_+oDv?j^~_iN8-)@Q`95m>BK6aeLySF7+KZrAgHC+V0~l8X2h$0 z^ROgCJY$rq z0GMP&O8iVlj`(G1O)Uc(L(5P(L6*llB;e3y+vX%4PJ{$Xg^Y5-aUK)_RH`9iIq;`R zn~&-Fau)W;W_migc0@r&=&ldpQ0)SVLWA*uJ7*X0O@f6}!H2+#YrJOMrfFx+&tP)} z(K&8Mm@ZS!l3a6|^B6af=s3ZTy1tBxfz{iPpQ-xN?D~0|mCVs#SN48sfx=`btZ*?G zd&x8SWq}`y1c6TfY-?dhLH0wlTmdt1&(>jPVcP2Abo-0R&h=J?;FVn*5@iP}IMfUx z8jGJ~@YsND0ixc>$0G>jp2$)impEl|5-5Y?s3T#9rZo|~?E)T-iyg;}IgvWSP6TgG zc?kv_+Y(uWEGxdnjy-lJn-twxB|)o0KEIf$)!D0R0Im+M6QD?gMKf69SAc?^j>`JH zI)x>gUu@$32vmIr|8Ar2dPa1m#W{s!5n5A!sHv|&x5;#f-O5|ex%gbD>TnVa?*RRb0nfVi7EyHs$nA4FRwP;ICJcs2 z<9#R7T?^FGnIuCO)HBvBQ1GTANw7hdSY>r~X|5Fzdr;)ba0Eyu#zm_gg!EGssPSiz zm4WiRQDSR)GB>FmRH0BvP2sq-)m!DkA0I6%gp9tKr_ay}+L@yjAs1TKm zFx8ZTX;;wjhTY**jF~b&DYVtGnd~tFMjD%eVmTm28@yxqdLp+6P@jjqH(k|p(%UDq zEv5so1s`7uf0l8!J}XVJ;|~~_Zl?Y2QV@$-hls^gJf&Hci($VIF6hMX39?x0lCNZ0 z>*4KKZ!E?p7+;#mc+v}ISOA}+P%&a(gJ3Jwjl&$O|K?>vkX6WH1Osq~9JOR3202cB z1{ANQbjZ%y0@ER=-5<7y9)VP03IdG^u?{iP@Dnc7^R8{yNf#|TE>+@;8SC!KLTxs> za!QwsJI$3LRd>4`e4e0L?o=m$$LB`!)M^2hc(hmIq1b2Vr8#SaQpd!0kE@epp(}}~ zmwQn;>@qxUWo|Py2xTq{in$S@?Rf@+b55C+AMeu=LyqLw0^%h)a0IbFU~|^Zp}+=> zdDkoQqr8zMVY%CetMa15ooAaUy268>&`L~P6tAIGXy7S>TlixkZZpHAnwIBhyV3d0 ze$J_TEm$>DwyEH*vWsZxELZYW(q{l!$ojplPm?9>x(6URXouL4#zD}NYaK#CLI~|s zCmwdK0w^(KOYKb_^#40$>2%;qF*(C)Sy*$75ZdW|#SRebpD&d%FsE6vL$dY7-fIVY zuPA02+f8CDeM(-J%1p!}?8z+1B~Yq&!r2r*IUjU9;PcIojPN~En+IqlK_tho`MC&!I%Nb?DT zRq{4f_Y5H{b`?F1={1D?Dt!57sRN1v99&KTH`jj&!d&@j6LNnQpBME@2wlpu-KY#KSNw$Q*jcaB{*4)SQ|V*lA2y zA0svsbewA@^mQ-uajn4o)lnN)kmj;IRM>zNuF);E0zFJt%@#Wn@YJ}N@kp?VL7Z7z z0JJ$7&a^wRC4O9j@;XFTXS5+8981P5s#;Ov1~C{f2NrM&)EV7K_GY6FiVYIuj8sa| z*39Kvg+y{lbWkI1Q1?N78@LZhz#~LrvUplCnljy zqblbCrye{{pdzt?8Eu^x^@2>-U7~Z&#@u0$m+Ux2N`d=ouaj}I_|bYHRI?iGF8jtT zJ!j{SrPaeB1({uaiIG=vkXmXmQgnW;+TD`gKCCseGxw?bcuxiHr(#J3Cz_zLI~MkFFJ5- zAq82%XU#YVCrx~sNIoAl8U$O4G3HnBhxH*q1Vax3r1Sb^)lj&Cv2f*i)Aj+1-R(vx z2zGBrm>>3+WyY+?Gd!@0w7x*UpVRKiDV}0SXd?Rn=r|E3-LgRFmmam%#dC zcj-+s0bwkO zp%i)2GSB*{nZOy;G6y>&t7c6D-ym!GbKVv)G}sp!ExD6YqbI#|Of3jQk0C~j*e zoY|^lZ!YyVTkX;_cDSOffh^!;k*O*9dNqTt$0YHcoNyQE2*+a+AI2(S?0fm5 zt8u9pxeOxvgOOIg!fG>G9)Gqy5K(E-XgPL{dutF+pO%MBc$N^|4$uY4q_);B zX8HmF&;p3|fa!>ySkWQ?}LZnr&ID*mI+S7r%|qoAZkG?Wwhwcbznu<{j<;>_}c%>t3^>^9C3!u&~;FsZ^fc zmH^zh!TM?(2A7Qv*=>d$r_f4>H?X>y7ZTzbHe&o)G~4$1!A(vcbwwN0-W7*=w>3$Y z#t&X>GiA;AG4N*AIlY>wm2%=eMgW9S9V*0yg#eakb)98*8$CNWNXoCflI?f%&IT1M z6<`5NlM;}^6he+K;66g?Z*Q0_2YQUV({$+jHZ0m0I{Wl_xf1qAc064`0_@jowa~cp zZl=SP15`>nzpye0T+YjVW%~SLWTyfsPoC=@(!w4E9Z@IPr&*659h0$-oU)q9K;XCf zGvtw49w3WLjK!-8pG(Hy&6awsaD1L)dVdTiC z`qVaCu*xZpeys*e>;Qq2u7%?^(qX{dybM^t)n^}30~BO-?A@B4O-LjMoRZ%z(r(os zK}JV~7Ju1=EDf~c8ZOg?SZmErE#pgm z1XiqaT-dy5HQEpvh32&=nAB{y&<+od;~Y;Xc3P#=80n$;Ur<}9w1{h5a?y=`GCFz- zti4OCU2-$SwJe>Nq$Ve0;D640aJeFDCv8=^Wyuo{+KBcXU5O#SNwknytN@nR@!?`+ zKzD1PY=`ptvd=uUNVGEt#APx~HIN~mk=rGh?4%J;>d^|KI8ZuIl^M8`;b>~?V|04S zbCm+7$2o-J3#V9>SvOafJp+6fvoS()3Iilzw*JH+x&vAqOi7L%zb5QVpMoqNJFW$u z;ihMR`~Y)P90XT$mmzT=k19gJlK^KtExYt+0YYsIrYGGocH44+%%tUEj4RDGSW#)} zy3XumT2Mo~Z})~c%HO~SJ4S{evOa9;?FK*>89fGq0mbjnrz6${9U5Z95_VpGrp@ho z%Tx8fM)!~gKol^q^k5=|@n|H(o^k7hlrDz>I6pX|1i3F0HC)($rDFAo2ReKK>Ib<5 zf_ds`6QxT!BT}tCt>uZuRC-nc3n-c-bA172r%Z>axgY0-vDc?IJ`O2$B}_0VSBd$W zK7#mT;mTeFY8xv6{7%R5Aef-=meBK(= z8c~OJ=Mv?4bA;>8;4Y4UU=GQQQI?}g=Zak}I^W$L$^!IcQW!fYf3?--$A~uS6@cjy zVR|CK3Cn{NY&mP-&XgNI&@_9;t`%utn1NR?^)`e|uc#uW$DoJ|$Bsj7?6i&_vyiKG zSp%MzPsC=G9gU4J->Kzfj6Mt`Y#{obto_T%m60LM2xO9bix>D1f-0U@xnmH8 zF9o3MLWi%asshAFw_A4!SF5&>q6-~B*-m>9bH&eC4*cNbufj|ZdD(w;APnnpja1tVwyXS?wx)PFtcDssyx5IByefip^*;E z>6&Rk_Yd{zae7>l|DjFT+3N{tKP2H4$^R0hC;vBp6?JM)3OzvG&e7R z_Q8S)3IzHL9?9l*(;MtL=3B5CV?4W9ET`IbzTX2!ft_@7n638Kj$kj)&X{SrJ@NPr8@4C&g!;HDI~|KUw5#)vSavRP?-cxw16PI zW-wMwD>-zr7M65{oq0MZc$G?*V~eof!~_-(31BwH;6iTUAF;1URUEtKtl^joAqIKi ziG2kh-IZuOisbRZ6l;|;KJGDQbA*4xb3FDOFgPeX2N7K0^+0L5m6~ijnVE|t&Orz= zW?=CAT7wcGCdBY7m7X+kK3aj^REEbWhzNP-s?TTpk}rL9ZV#Tg1X*!A#G5>hR0?XP+j1cvOC3-!+0K+ z#B2qjkC%{%Sbn-1Fa46riLF*LtR`|U^m{j}ohh|MIg3a0GbrZ|np2W1FyjC$E<7+5 z!_cTWZkg8RCecBSpOT<#2l_<03aHkJg-5zEEAxyj32RrCAQQ_pXYhfO129LMJGJeR z%3il=BtZ{tU%DeZoEL&^Jz%*9Sp(bZ(DAM1)zJajrCJo@ezohk1rKaW7IZRqHt;`L z)HISJf{Dy?Vw@eU@w-|pgqz_2hgvZc)sI)wo`^k3T+mZDz@?Od8FG%sK{qj6(hz2I zb-4>6%aK@uDfb7KNva$QYBI;iCR{XlA#kg;yssAS1^xp9ve5fU$D781Tn}(*G4w>v zdWswe(?*Gy@z9~h?h#+Zb`o%nU8bH=%3OpY$|v;F)t z3Z4@7kP*V+kV-no3){zc4g!cMFUmo}%0wh@T!9^Nz49CMtb(3v1HdWV!-_iIt^%^2 zOKze9(?SslIZ5_2zd%t6NW}22!Uq}-~Ccg>Y0}l6ClTmM6jLLICvV?C-4*zhNF7PXq=VI!~nbEMYJlC z=~$cs#bfyQ#HkcE=|Y(x$8)#m9_@sXF^VkWVJ)=M7_!K%w#^PJplzM=oILNti_NL2 z1}_gQk>O|Tp~T=&8_Eb)!g_AhrLeYXPw-3y^vYe(S?DYP5}xnQjEZL4J)|(O-_KoXVdpfd9LR>9du=zb^MFR{hgQZCf2(xSPvPF4r4JFKrT+Q$R(klaSgjV9)76Js8xSQVy-NUOsVQR53EhJ|9ia z2TNS)*nfiJC~X5ESBI3t z0>6G!7RICmoM|Ic z?xkW8TH$l>MKv(7JOJ8ys&D1I0v_4Yuuf@|EveagJ8Q~VaRJ&`1&om+D!8w@5KkPd z6PSN)G-$4w6UopZ3%o#u@suJqicdkmc35Ug?o0C!5j}N69D+`^_HhF8J?tL*4IJCZ z-B9-G-UZ7hd$Gu@^)~It#?XP?(?!hAHYrsd&VXDjU={}xMN<89!8sAujEEt=5n0d+ zh&d@-?AO4bbD;5GLi@%9ddGO92?OfAJjXP68`tC7q7d;p*n+VAnh*+?2e@N)>=LPi z+abXcuWhM~OR-RdqKy=58*CRu(s~L@BAXaVTh`=G&gW_j395mV4lH2oh6SKWsn6{o zDOUpt?4y7yL5K1fD)0qeA*c{>9OoxDof&dpbpsj$ADjiDOAw~`AeS}&=Ej06|tW^7FjTa z0|hV@3!#@x`iG6@H-uIl1SU?daHFra3-_8Cu@P>Ch0CODJ8&)3uwa(9X2+`iH|l ziVaffz`TN4=`HZz)fxC$K%b;lV6m8BTPPZ0l(q487nRvwlFA3Tt1WebZw0(lwFG)e zmajM6X;2&h z{3&;F)&&eSmx^8y#0ob~R|*J656mt1psxXz;e-HaB-Ne$r+kL1sAUJJANC!ZIjGuse_rjGL>1@LT37C2X8`Ad6dR#5r zAr2WLLni#%+!hEt0*)bQ)B|UKBY;153)VSy8V7(sj&Ywkfqs)e8a1-&2^g((c6Pu$ z)iUsOCkKsy(q_j&B^IPlNKb+D~#($XBigI_3H z=&!J@&qP7sN4Ol&0|nAL_)uf{u+BZgF+kGra%v+FkgDLoh(oNC`iuN{wg4Bzhj6Ma z9@c1hw0wk0qP0>uxCB0ysoAD*J5!*B_RvK&l?ABjxGskex`@DaI|srpF$Y{~%8;zG zF_<`09Hc!~Au;};Sc;x|zAzhsi^8Lw$+5AYox2^R#@d1}SD+$u+011U-Zn`bEhx6eVo5m^;rlc`8I{0D}FNXf+%kPqLkCqW8w zwE)Q&BI2V(JzbE}I1;q~Hs5MNSgx2udDtcQbHz})c@JL{Cc)Rqec=}TXYPxEUb}jQ zZ?$Ma1Z`2w57foW3o~}axE|!*;*zshB_)`G`Jp#zVY4R3CK;mxEQQtbgMNfO$xNlq z1`IO%HUs|^vc;e=FkU41i*jH;UxA-CCB{`Sv8+v_v2k`5BG92_2JpOs>UOq{IF~Xk zI68yofbCS`ZpTF8&an~^cqx1^IG4RTHIFA;JFIpGqV7XrS{39G+~gr=Rw}Tq-;))c zWC95kfN1nR;-Nml?;<|G3{!6{jNg4|@FIya5VHV@Ckb6QXA{)`>NGbpbo8i!xoS{Xxt3jcDL-^|)y1Q^#qd-YwP1Y{3XMxg7!9&rJ2ks-8j;qmS z+{}9{XWX9SdVwmf82tz*_feE@#rGUfjxd@WJj#?HEu$L zIgXgwCC+qdN=DW}sr+V=)UIJE_7dU*J_D%UqDpnn?Zp8QdVnGWUBzsrsoW@31PMNK zA?tCqqt{}b?iPsXgLQy`tCTj~uBK####?y>nAxRGuJ?(KA9o9KHc^)Zif|Y-Ow7Su zNgPyN>zyp3IV;0f#`$Me-fugIba)CIk(BFPSie?K%@2YkAJV>$TkzzNS$6 z=}0Gl-(vJ7c16ApzS>Ao2YcmRbTq(r54QO$9rJ`5NWPSqy?fBl`_oLcCbB)nJLTf= zdESCJN1k*W!r^MlEe!;!=&1#s9xEjU>-P;PtX7gTv*sYLbPP+YugAg}0RZ_FD(lb= zk)vQsTAZ0|W?={`%oa>o8Mj+44;)yXCT*NcwFcNqMLiiH&I;d<3N;N`=2fnxtT*fw zpfFJ8v@8I(Z@|-_#SK2s`nVJ}&f>HhI)~TvWZv`Ogc}+-qlj{`QxkKif_SApyG){` zG$eZkBNzQ(DGL^ih9+Sh6_d!#R zU9kKLntTIjzG>$8AF$ON#OkD+79j`Jp1GiM{ZOxF3^Z4p0Bn`vn9MxHfIeV!Xj)50 zeM+n`P~fh*>G)xpztX1NMq34g&h&@Y+#cG);)sH>SYW6r!u7P-g4x-2$4vE<8Pbl^ z2DR$PWfojipqjwus;Way3(iEkFs5HS+*WnC&Fo^@&Pb)~aHLzxhhiQZ8FY=j>0&a5 ze*EdQ-_P8dhSsmFx$Afp+YY=-v?wxuxng_-)+Rm_Mr5)iZ`mA(@A~^hRrI< zp;cxfg9k?gb!N~eETQx>2fg<*lc?Ln!XA@jXR9o_3RO1M($QoDq`u}`Na=jg(rd|# z2&c;otfWhv0EadKNDrg^4}WZ^x-$-Atz@NkgaTX@C!#?&OKA3SM@~k~drONm;jk7? zE2lmR)O-RC4vPtY;-|4II)YK{_t}LkGEV6T<{$(su)Sv+L#Ny-jzTLsO{SV}pSwDS zkg%G{3meU(lH*j|&5AmLR$xJk>Ns7*)KYe)h1L|NQ6&Io#x>1g&;x#^;zh#i^Yu9? zYC*SfdWq+|sUC&|IR(CQJ`PyUoxu=hE6Rty;slJKV9{bh&jU5W zjl{76wpkK~+Ui$e)Qkv)Qn)sJC00Q57K}ly!w&XaqjNJIq?C3@n}L$81E;u~1A1${ zoQRC&cW^{O&A9z2?8riBy6X;5RLFc%2Xn%yk5@z<@om3@69M^#N)Vj`=|D^B>|>M5FqsinEq7}_iaAiL1Z(NO+KkmJbnoffYHSwr30Ur(-RXQqkfU%g6->rY z401;f7vW{Vg+4<#Ts68C zg8R;oJXrHp38(p`XM5xM>7Q9M;3qrAo0!DWep$5N7>E>&7zgtWDW=eD<<@* z0S74x%$1X$j}f|bF{J4lUYrbqKZsxu$c=Saz%O*b?X8!)^CFZxWIbi3VddyxTPk%5 zcRCC_vWi>AY@`+_UF!9NLg z;7Ln@G(i@)J*g0hv%5OjF(k0zY&38{CUr9%<7-xByp3>77yejWJSmGM6ET-8ax4f7 zSOXYYEnqHsT9(Cbvrt_!S?4zAvLJu4&z2ct%1$n^(`D=63csMMRFjLvRALfss~F<3^`dcYGQeWiJ)rPv61=Eef;xbt zxg3no@vd>1i`dRRb;e}AfCtttDBc$3ELw82xF7fJmNT*?1Em!BosnlQm|&=!R*L!* z1|i>^0l6@NPYJ=y6^j>3k=Rp`0(AV8?KCPg8_sW`Fb5J6Vu#Utn6dD#rGUe*XKC0_ z&oWeUlO7+DKYV#PB#X%3s>1;=3WR{2Kv@eOKP-cjFf500n)=3|oKAv=hzzYpLBCz| zfZN~9YG^VEq2qylM zp&fi&zzFUzN9}@K0z70>!1Wb3)C%7?&h|XbFFTshr&$U9F7WDe(JzJ|CChfaroo^_$ZGYN3EBvF@M)l&$@d#Ki`~tHB%zIX z#j^pc$NOWgWiZm?JJCGrx~pk z%#^u6YZ~m8T2VYQ;}A*;4D``X66m8$K3dBf@ z7w3Gm#9de*Du~7Po|A3njg%;994_|@cYDq+a2pmK9Abc6zn$6(lm>rZkk=Jh(j=Yq ztS{lo4A#p~1WZ_+VXeN@Npb>4g{vWtHc|v@7i-Mt0Zs*6h;a)}vz9(mf@~{rX|rtw zs*9)Uh;BsJr%eSEyai?njZEF_2WAIMJn)c|r=bRF)&mRm6jnO7?!g-jR6092b{X%s z%#4pwXaXziqai)=teWIrUO>Oy%LI%I z8!}k$bGZb4h{Jvdq7ncRJ*XKK5rRoIA-YgR$QNZ{k5uhqjYv-0F1u~hou?sqBrlIC z7$diZ<`YT1#Vu1FhcH+X!avhQNs#k6w9>vPBJiOX2fa}*=0?)g3B(lZH1#}-VHr+j zamYRFF3d{EXbKqunKN~^^(4V&GL%jIIgn8Nw8O^~5n7;~3+ibz_zZ5WJcFz@boa*l zQS5!F?_YpXY^>A8tPKh-I4$O6qQlDu8T3)*so9HQ0lcdH0z1;I;GGE8H-YdaDiT^l z!xORwA1i!Qfc6X&4a%Ib16?6GAFNwcZx(1QS`68G+-&BTp<;rx-wst-VFfwoCX=pi zLFpsX7FK2(`$6DWkeh)*WRQtEv7)g)8so3ihM>w!wrw5^@a%y)Medy5+ocN|?=EYq zS{>I75@@KPJtC@9Rw-3DH^(#1Cm?=@ib|S{xi_3O3k8EwR)LqsUKHDo!SmPY5W}N6 z=mkMOj?JI_qS+l+U3K59K(ggE@^|m`V|V_&$9~7JU1L1-8y8pq_3%~lQH+-}6+Wym0`PWbW9@2Z-56sc0eEDCW@yNp`|Lpdg-rs%T(JyW;FVUXQ zKIzULP{@tP0Z_^3aB+)eK-ZjPSwhWCB#YhNyJ;SXQ=%4gs9n-96+G56Pg=driE z@!b!9;P9B=d|(kj@w*;)(}$jN{aZituYdhr-2+17lQ)0zb?$xYH~i8!z0`C*{Pger z>MKv5zT<{(`D}E{J%9RBuYb}zUianyxc>TodhE3~Z9epnJH8S;?=fF`(Q{94U*2}_ zpSgAQe_WaWx=ioY9{Ot!_~!e*_xI)By6!G-dOUsGpIra*s{f=v3U2+Muf2R0eb84w z{9EPu*S_P)#}B>vw!e86_KCmz(dRwsXCM6X55HY};Hy9VpgV74ws*bcVY~Y~H@^JOAAi?dzxT^;z5bK0;_gYm z@sZiPZhV$~}k`#kpBe^hzZo$r3H_x$p={@_hNb&LIC$9YZhh!?M4_S;|m z@Vg|d7Jmi*L}J8+0XthnLqE*pZ`nh^S|@>d%og9x4*A>af7|%dEfpY zH-GZM4|`?&M`|KwY!|9EeC%jfQRpmXDUoOgdW7C-3`KlkZ3z4?YG-s=G$ ztMV^=>VuEDE~CHnk>7pE{vE%6ey1*c@%BGlzpQw|_nx15#|@7?Ty-P!_`A?={FOI< z=lbFPJ6@=EfA@yxyp{UUJs$RF!>b?txqn#09U&j`>o2?Ij<5af^`^uM-t8YyI{sq7Ii9fpa`yO_Wn=hY@pYSW(FWuvhf~P)+_|U6ip1%J# z|1Wp@$y@Gt-OE1o_-}sjTW^y{DcvKLCv+U}m){^s{y{^e=Ti=7Aj z*7e;>um6XuI_evqb^RS*_`Z+(kGb(T9(46f<=_74+it(-TZVT(qWj_(zw!S&vHzc0 z{-|5O^IO8dp1~Eiq zKl#kcr?}o4a$p@c#?N2~l>2aTX{yU!f zkgt8@KELywPkHrkKj>$F>JL8l-q)0Oq3(3mH6Oa?gTL^okAL8nTmIo?kNg0zlAm(j zM{h;$xc|#uTi^Yc-}QX%u0Qzb)7yV<`QRu2`^UcjedR}PK7G{aUhzjS{K!|}KGJ`D z?N48W?Ed!I(G$LJ_=6XI?V9W0>3{yar~b)H9&z)#KKOy`Wq%g`+jE+y-2DD)?(xQ# z{IcRdP30yJKa`)>npAfAIo3$z1Pj6zy8hkEq6KHFGv6Qe>^te zo^a3hZtQ)2?8l#R)%QQjy!9#H`-zXg?3SBe{L(-Dt3UtCCw%{7p83J>+E-lttBd>J zDV<;Yp1=59^^jkDy?oa@UHi?Zwady4%sUu($l{4}?EB|K=}T z^Tt2<$P1qMzTdd>d!GN%*SGI}#J|1t{*Qa-gWjjU6aD!6Z+^`OtFD|*L~qH{GWN%y}s*BpN7Gi-qm`e z_`pB=m-{~9j%)7r^Cb0}$DZDK|7)-N=Z{_Us`)$YtCj2S{yW!t$X)*KSHENPx*vJ# zgKmHL%kD3DKl1TAEsj6^El;E0e3$z^{g=3@e%GJ8>7&>FB=f2_{pGFK-J2vL z{Q56_XcIi|CvW*m{jgW3zww6E`?}zAC!qF<<@4WA3KV|C_Nl4~P1DOe$afljojgi@VT@pEDc&Zda zGupH_);*}>Qasx%wcUP6<|Ov<8R}Y!6@>d#{qvR8NbDP7Nr7l#6qj8(JL>I_IPd-3 z7j0Vd%+VkEvs8}CUPN}I>s{GVkFs6o?Asi>yt4mxzf}I-&DUWSFM-wTDLy4exAQu%TZ!nuf%V*Cufbv3m_ac{3p&uEVE(IZuHTBQED367Zu^-x6g-zVv2|=pI@2TG9rg~nfCg~M-{j*`Pvv?WG0woK_ z29vNZQB-|lWJ3p@Z-k{Qu`}7d`oe;t(=l^Fl~D|MLSM;?!r3-%Pgs8;+M-Vm>^9XA zjINvkp^mbAHa#nZ$ANHc(tS`vHVDo%X;@&_nN&?HKdzmnY{mQ}`Ql`7=6_}l8Y;_i zHFT6_D!YtJ`!j0Oc3D!KHw!wsaLH_5psl~%CEJyF z=1ztmoZc=&{?KqC zjO)s!@*mH#EoamgtA>1(=wLSBqM|}yp0I*&V(8H1`?Q_6REkYXOl6#fQ4KK|eif{a z{ZQ;2LrI3>&rZS{ov)o6w>vx)L^14|GAUMAv!JHc_QuR@3<2e@G$`)b+7ep_)X9~l z<3m+AC)GNVdhU=69d+{UXsLB9*e?6|Tw3E=dH1bcqSr4x6^IyNDp@=&QhT166*+$^ zo3l$mAtMDA$HE7NpNFd@e!d*VA{bqoa;T2#rH@6*ZqRh9dt~0kcOG4^51tjs@}yr* z!f>Wv?!B_&x@}1F4w+1V%jcofp#K=W_nEiEQeHGVD;Cx(VDC5S+RZ-|`r{dAbm=`C z3RFlWMW^Ap39}*AHUr|VtX@V34v6g#1wlD=_P{_N>cPVf3XzdH1<@4G6!rcy2rl1% z#F%53h6{eMC z)QhCduWjkF|KC{ReDp_2sdo;CWP--2@@C7@Q<1(_wCMqPzt)c|kDe!9z8r$P z{rM_|7=~JbfCzPNgI7`&W|H4P#P?yD(4S+NnrnP(wyi0ev8v8Cfr@lehyuMB#^kg$ zX3L(x8b-z^j{7BB`r>8(4V4s6VRgtkn3p$;>bWlH8l(BQu8n%%Gz$bX0~pPa;hwGBTQo-ixrH z{H4YsG$60WMx@%G4qKulhch{F5{WTlRQIll<32bT46~JXiAV%tP|#+9JrZs;q|6Aa z7o*HUSANs0m>xDmmn7QP^perdZQ;wOf37zP{C%D*6#z&weu{Xr9CM!(U0%fac<3^x z4E3VEY6xLEbhB4-l}{@^yWv{F8w|USQh_7!uZ#t{j~mN6N#}mtBct$sI#0gO07Ye} zTZ1ZnLP_NkA%x-vg$jLXBqaA{5Q<%c0vYN;Q4kA}9AE)gtaRRDLouVtv5b(YB!|La zVda%5jbu^&tW_R3_2oncOAF=&=u1WVbUj_zxsm1oKN3H}UI7RlT$s=>ED7^3(z;sTc=?XbD zsZ_v_r{(5CZ$9Fj%?6EplZp1TXw%G!gBvTN8{s&L=Fm` z3^&r8RA*HAdn!W1?B4&rW^C%%u=ejC1}Ad83k5d$>XiCK;pqRmWi2d(bsw+Ujjhx=SFZ*lk{XNEk+hE&Nj#~5U{MR(CWbL2*aIf8Mq#8Giou6wTBlKdR76WZti;G^(KE`|mjr zBE9ky!NRE9`-V3g!`;U9RY;+lW|Uh)`@e2R?HyPXZ_bxrU7jqkQ~Ntjp#xU}8M~XH ztNKwJ@eu$Qg#b<&)bj_P9@vcAybVC?PkMb7KLwd_w}gz2?=|uTDEX&I=>-r6096lo zoSy8P&8K84MOrWavXr(Re)q8Y?DXiUzJ?1BFq=g`0ohTp%Ekl;gEz1E3|y1_N@d*o zB>w&sE-)VeqTol4J-f~QOttGKu-^6 z!`Ymin|2e<#O`cPHL+at*qD&mFzFlxTtO+5b7TMK1pd`dX5Xa@-7Mv(r(VmHa>p}$ zOL=bt?|Od$cLB=p!12bx_qs*)cqp{XO)OcR|qK#Iy?0ov&+K)I$>R) z3UC~qXWP5)OPW_87pRmcd=^uyrvs1hq|33xgiXVGb#hx>bHLFUcv?)rKMz;l|BB*7j$}jDuOI>9Ivi+ znv}FKQBUA2BcPPNQL()2@pozh9)(}BN?jcRgq8%)pmt8Rv&2#WZ?A#+=Z1OT&*j0I z+F?8eMWXE7e30AFU+=kYPtGJ-IacK`4wfqOA+y(ni>Tiog2XO60~UjHHSwMerC32Z zqVO!sh>3A_S{h*YHoine6^i37ug0khxZI!xw2}Voy&I&s*OWIR%r!h#M|RIhp3Tnr z0oG>Xsxs+1py}4f{12D&Jdd|K8~iAH%I97XNcO&!o7NPceafF3CPn4_NqqbX;tRvp z$h(KU*q?x`bd%rzNN*-iNM5#>=4YM#Qq+E^i^av@jTvK8jBUYRH_CwNpY>8(KupQH z#*-nQhYT@s>k;wX!`Y0pcukZM)pC5lg(^e~kUp*dIiN06B%CBbYU}s?l^$~@bnNp> z9t%JM?j;|99qNVrxs(-ihy|E8Q)(>Dd+vEHI;!$+ibwxaUq;Bo0-cX7v?+-B&KRv< zZMR9U_LfkIx8%PH3t0!GN@*vn-xo7d^m>%oRO#JU*AMoQ2ijVWGninC@p4k!ZT>pI zl@6lJQBCL{J}0}_ad4xECg0t==WwW0c^P39b@OZY1PFPuw+C!pXuv|N5w+|>R|f5% zA%DH^Y|Pmoz#8RijbSs#4eOb$d*E)*B*Csjr|?vq!$9`)hKY0XmqbP&4{9b{yG7aY z{OiP|G9>63_tu<)ybeN=Yr69jv1Md$R#30bdbCRZi^aQ*M3GC z&C+pt#HA}Pb;WYZn7pi1G}g7nrgd#9?0Po>aB;2ZB51^m~jS^6Fk>CD?=VzWg z{ZLZcxe-f?z@ev{9`uW>yJI88CacA5p6g==tkmB>zuK8Um&6O1qZ2eO^rHUzE|w>? z#OKjSbW|8b_$ua;sK?3!EwwB5!|$w0T2AX5Ml;dmotN#h#os3gn5@{Lhv-%Zaw78r zU^tWlLE!k51AD^`}Dgy956C745bwNjG9s@eyCoV zz0o$+{g18T76OEaaTmKiX#TzdQ8efN@N5KeES{r}(wxq}=J*)vlmtmx8Gi3aYBltU zPV}l&6y@bS_vL|+OeHcBgKP;rO%)coQMZ`p=0%64^cc^q3Dku%EoxUp-7N4?zRoQm zX%0uDi3y>%W9jZDU9&51O;vu!$5ZI8;W$=l7p)*V609j?{bh?0P&;OYO#6&1ir5k} zXP|XFU$rE72VL@JSr@0*SC!=6{WjW?6)pLhiLZ`{Bw=e^xPldCX5E_7ZlcvEClvzO$=mFAskCPpYS4kVQM&X2@S zk`Trl&lg~|^|oIid|AXy<`I<-lE~1mpF|)JFx)AKe)A~f#Fu*Mx|$&Cysh&Gs=}1E z2%%m1&O*=5OJfYfnE&a)qLG3qbLdbb!9d2U!5`uJ~#TMm@VN8l7=UzaeN&S&%+t$@yJ~ zaY(;S188tAUJ3m?@%4{utMjjh?W|1sZF+by1!qN8YU=PUUrc(1iB;CS2W`M-aGvvk$ImLCx3p6}B?SuRUC!xuF?i&~8`=OWeUkaKv z^KTS`+Nr0(bN~f6c2IGan{e;%Al&XCn2MFH7kb>EsZ^|mLOh}nj!Ad3L(DWQuHHF3 z2Dy`h6Bj?}Iqn!a=t2)$|1h=wVs|C!?;A(n^*V@kBxrt#8q*1@EcpX? zUgKn9f^2!QPH}X9Mu4^?W%fZnT6G<^QxoOY_#Xn5e+=M@==p2^hCa9njjh~Ik=4FZ z*Lps^)!%3@`vnY-QUa*bZjI+W{}(I)JhCkMXJH0c>H=_BRAUwiIDPnY0KC-adisa{ z3>aPvHfa6Le}5^$|8sR6I|uGq^)a4z^#=Xee-aCo<)F;}C?vQpfVh*F@FNfY_5i%Q z#JHj&%Q)!c9eRYBA|9JsE3nq*o*t|7QG@G2*xGZV?-eG#EEJq_~j@oZ~ zJ!kLu%d)4}nGxdn<>LC<^+hMFG`7UV;+2-*o2`kQ>0>(}bsu>~6~$-Eo8l8N{qF&Y&kOo{k%Va! z-B0Z0vw4CjsIo~a!Y@~N9v%tsT=klCnMyv(C1*ueq#CVq9BnlnRJ(6?>MqXl7Ao#f zRWGI;IPSg^M3`?D@3l5;j?P| zlY`gZvh}8*cXzaOeWi#UZiG?9*#Zd9$s#R3VDIl0)SvvEzh<=GW6vZJcv$Q=Xlo^r z2_9Z_SH8l(`-u|lr}>zF6sn~gzq|9`r&`_>Ut1d=#lEQUIXJLc56ewF9D=Cu>hbNI zZMv!lIZV!8YuHGNT(3&BHmKTJ8T~9Mw~yUq79Gi$c+*M0V?ils@c9ulbqXTnaLST0 zs&Raer?P~WuQt)~*QrsE%*dzeec8~pcahhRT;z+qro4w2Z~AR;j03<`=>FV4aH`xR z=qSjdWx=luP2Tr#_eBj1R0k2n@I7QS%kg&R2Ex6N2}3rxype9{Yb-r zb+7ja=vXpak!XtBtO&CkN)*PAa?fKi5x4O3IoY;ZxHj+UpsAJ(ZvUcVSnj`-nt={} z!?TYA&Dl=wcq;^8eCffSiCUK|S{|J%^R;Gy2Osx9j;pQG_5ObwOpLN$-lzeSJzhf1 z1n^D8_pfiKNolxe1h+-7>M!Gd0L0C>Be#Vi5DGg4o`^LR0Y$0*!Kojzt9!KcQ3$|6 zX1bYop078rC5bq!0o{LcW5**CtY)7)Fvuyp)OUKM=px<8L(#bmp6abspOJi152;*t zhGda_3$HlZ>y>XG7HD^0lXia_#U}G*%&V#AE?ZDsmGdx0(tXu&z5eHqpI@~(H$I~f zi+a~Bri;uRN(7E42F1Qg7w4sp0YZpUmcYsoJgVOk0ja?RhLK zRrdPWV!QqkaN#^p2}W$BYWg&E@=Zr$TGFl7unTKeEDh66j#AV1Y)vhg;(WL2kk;pe zreK7MY$@OR9Ai?qq+}m%wz{X*z>VrR8b^LJfvP4TP6B0|-=uy>7T8o?G|a!Yh%V!~ zpF3Lbv07bg0bDQA0P=m(_3E#6nMW*qd9(K|UaO<24Rt@!)VCmtO>Qb$eD0i4ooi0a z&d))uaS^Ps1v#Ve7S_U|qx#fGQCEN!Dt+hvc|ECd_3o&@~tSHZ7M2FN-9l^)N; zZTsMI#3*lY;Wz+m$E)fBb?y7X;=stV-hN5zpGmLtJ;tPKbNIMA$CN_QO@>~tmI{-w zo!j4ZA=zB+EvJ+{Yt7XR9Nnlbifs6B8-hdTYOtQA%bc!B^ZHpUY2_f0kFJIzQk^y& z9%{-BT-TGE?vNGj)^hOiV9{!Q5sHwPme)gE@R`wm@3;IAeCaTsr5|!sp5$?v!ir6* zI-|Iy1^F>&K&5Gq>&Q%!?#S5&2}>>X;Pz0%9OJ5xZT4rdNA4}%w&ODbO9EHu^VeF!#kF&!$+7|UAw}Q2 z6UP~BQXT;^#S{Kf@?5j^6yOx=gUC^2@`9u|$o;%j&!kDWuQEo)Stbph8&3%=q0Ph> zF3+4HOSS`M%wAo8_(EbXAz9j|98i!`ras?kS8)}2yQ|o*C6fR=`2j|RocmmR!|{*q zBH)$U0|e8-F5dk^d;I|t6@EY30u_EN^=l^N#Fu@_XORWZQnFl_ODBtN=@w}C;;{nE zEhR*b%N=3KF(-b29b(~V2Ja>6`_q=K_blCH2*K@jh&$x04yTfkX<)OBs_SSG**Y@CAR>Iz{o|$1(MfC2_2!kLW zvP*Y1Wkq!M6TMJ!v1uZ#7%<~bg#XrKB7fpcXfiZE)Qwb+Ze|6@PpNaG;wcxfbWl;iFpiK4spd~;~C$L`c(lE-+;rPWqM`-!x_WiZ#H;C zQ$tbh6i;5tocRk2vgGNGP_hu z0x`~ZiwiR>w^U;iN2as2RY4G_a&AikNmxTQ?Tb%xy+nr9j^iVE*0e=dKmo-tb5m4* z9QY5{{D_drV$Nx7Pr`l^w$z7mKe{+&699q^{&pW^%!9ovIBS}sPo5SrvK?4qLiq5H zde6ziAIn!KGlsMF;Y<@})x{Wimw87 zjj?CNc@^B$tCmI;zBvRpCmtmlyF|$#H|0wK@N1PExCUIWf;U6WeoWrYFNC z@S7k!-lwj(t&TXP%SfHvEa~=~9xbRogDQdQaG^-baQ!ovU($-ldj0r&!Cqn12p+S$wW>31agqWNJMmBTB@NDGJ!Kt(H%0PVi8jKQ^0WJJVB1pKRMhL5+;27L*z`Ai%5VKc>mG{?%tp~{3lg%INy$oD|yN%NB##M zfHwNY^9cbm6J71yA7n+AFQnX59uh(41tGTO$RFQ3b{hoZbyi9LyaQ=^E{@#)fT>2uI1?p90zRZ;eKZ5&lKHOvk9uF!D!1 z$mPj^X--oM$F(lMwc=)(Th+zZslJ}r((scIz7)ie57T}-h}hS3jY#XO zf3OK2I|4yq4{xe45r$k8;Kf|Bi+-MhKwR2wR*i2z2$LeiTzX;+0lTw`;QA96WYYGu z&{yffN3^BqR8>|uorH9lp^hMt1a^pH_@jyXW?>c=Fw6o1e=J>EZ~T`s!q>V5Tik6w zZUa*8)z_P%X+8@K&W%XmIbNB92qzDTVF5b6-C+ls}iCk$yX8rntXFu zeD&43YZlYSjYV3_kfkiHqtpp7hhu2vok5}s|B>DDoM_s=SrDD5pyelzS4JI z{L+AuM8eAtz-&3ts;wG+L5yS`S_HGqy2j|ymj6&$k^>d#ehCA3<;(P{9YJ>W@^?&J zf>)@Ev%* z7Rk3ONTTZx-EEpaF4yb8-9S>gh2@@%d3L*6xo14Pf!VNyrb`I?^~XQ=Tbx+V%Co`<)JX-&7E5WeIb7HY@Kl^)%|E5msjxdR*Vy87S0Zm$1CNSN5pWD4-^9{I?2yldxsmQ)E)9pc&vIu`!&(m@ zH{e(RC76*9i;Z#j*`Sa6%}r~+Anl-URK>1_(wtMRzZ=7PHf#h%d#sK2vooP0EZ%Tp zi}gxOm3u^@ttM+*?MDiid9n-R9nkfk&j_i7a2Vn9+U|L?V`f6{ z4zW()x!=)3>{`}~R*sm$C@e6L*T=#!2^Mmu8Y-(5%{qQhY)fEp*eLG-MZMLFVX?=1`y#dQ=lPaHc6fnlq+34?sr zUgeb+ZPxO-$4Cfy?o{`kZgcCoTg0Iis8e0dl?Vu!9oSLYLl4^wlrIsj6cULO%Hvxi zZrLY6A$GR%@2*>fGdgU;nKY$(iMWZnL?3-13=hpt8{nbcr(dNM8GL;yRcZ&Y=T3t5 zh5hX)ryO=j_zSW%j#QdSh)1E=Q26{&c1DD?Rk;1oR1mq**vu>Hq!U7wF5Lc$=>0q(h&_l~E>2c{$!yy%z1fI0Iu zuuSqi8&W1ykhwLZ11x%PR2e=#3S%ONC{+;Fn?(Z&o1La08pSJ52R|Q9Z7WA_?NOyQ5dNDWmEgvQy?3IG3i*qqDc6&<9s&X-0`G3HzX*ZuivP5ap>KdO@HA z8%CDTR=3HJ=v!{;OTl0IfZ@ttTsM%Cr}RGek5vgY%r+uz=pCm>LJ;1m8K(s%<=gnB zp5Ji?g=bbN@w-qD>_Q!{MIt=RR&pjok>4u{JUL8ztY2xXT|!B>65Ep^F$lh$1@Z<| z-N3!yaWh+k7)#^PhEfI6Oc@=&p?G^6E%A%CPef}RC_DGl!uxsA1lR4~xKId4!l|n! zkcQOq?EhNF$dPx(uFNb0PNW(%SXjBs5K+M9Z-s!hE{y`;%W6AOTjsU-qUa&9`!~rl zg;dF;^LIk(xa_TwpuCkFdR#P=LQswu>WkHdf5|Ra%Zs8oqU)$Rl`oUiBlpV$%o)h9 zJma>s*z9}l>sI>&>wqPE1rJQE`j|uuV(oTujZKKNO!#jTB+%Y!|g|tUj(xo z1|Y#hy|y!9^R-LO=mjz=qFQQg8K|>3R53tggw_#*Hgf2%1v~m%8hBuJc}2><%|vU@ zV6Nrs&wM5IKw16(zoNl+dw^CWcB7T4Ye{wZm|%{5n2vqzg%}h(%Sq$?m-6$W;Wo}s z-)*SA{o^0*M4QYcL>*sgSR#xE&yA5f>VA1Z1f&O@{f*%-zhwo9bckj(b?UqC7*)o9 zH8fY9V0XVDI?L`)4JsCD!9%}4d|wN#0miUt3Rng_h$JA*9#=G`P|aok{lPw&^7Tz} zWaZyKIli)?-;R?3l-ohUrKV0+gL;9^OGFl*p;pV5ES2hb`muhTr|TN{;Z~7a+MZNQIAV!P46B$*H$sxt>)g{rXbjxxOBNKupVy*I$9GwN5PpTcg-O|=C<=r&0t_au(p36PbeE*IP4nys=q25;P_>|sR%F* zj@x79uime{JWRxi&<|IIA5J*YSKB!Vwg?3^u#ZO!zh7_V$~7Do>Nc1JrmDKytGvyn z<$A)U5qYWhULN)Ry^;stm@my&A8ph)i#b(R0mIJMmV*Qq!^YRcpaZOv1i~8dYgWRO zpZuj8e&8tyJ~BGST?yFuzTH1d0{CtD@bl|rvs}3PTb?O@|L0*{ZB}GWx%Dya^-iM+ zXy=g;^{aL~T((&qD=cW}$6LPqke?s;c^}@d-p*qya>bQ<280lA^nG9#wVm+xv+Nl3 zG@m_L&$fuE+7?RdMIpX69ww19@1Re9jF)(-?XGz4|55?!k&Y~hjC3z+z@Eh0Z%V`O z49_UmWoU)9{TmT)JQ{Zd9Q5;|P5ch5My$3biMym)bd;E6cGY-2Mjwsj*xl##>1rLx z@6_1an7isZlo4*&*%4J}b#i#V1Y3Q^QoRA(XSz8Rj!RFU@3mGIi0s^YJWhK~>nxkuS*7wc4TPgT;Cj?|_ z!aVE;Mmq@?CcPKoepapiy!z?sCZ;>mOnqe`>n5e?ACnt&Eas;bo*-~1QL)mUWaU@C zn%Z9box*uaFVU#hgx99R%LIPwFQWqBNu3y=`L`VhEV}M7?JZrAgHX{9u$cvT7Xl*c z>vb`^6^f#WycByaVCA9Z*J;;n`T3+LN$l&2*IMPUk2r1x_}fVIO3H>Ewo?$CY(XD< zqiVmFnAmPMIrjhfPSLA8;p+B5tur?=PV)2qGYfUrN8Ne>R(WhejWQDM&#IQErxmd0 zbM6`}UW;;^3?n~Z;-xu%jpEyFgzqB1gRh&G-{G*HW>|Dfls(2vj)R8LNEF-71>B9>3%2%uCJlY` zJN9hwzT^EZ!?}V|@^`#qE}!(f@8wlM?r&ST%W#0g3N{=cC++<>CQ5@H04jIhE=|t) zbU4uaNMlt3Q48%#W7|#7b0qE}*y5@DD(mS{voW%II5M(R@4#YD&MGEb93Le-AhOz0 z6_YEs9L=XA)h4oc{T_Uy{gOhpL>6o5yYHVRE)ymc`9F#d*xtqJW-0MDRPB14ZZ3^^ zTr@2KlHExoE5DI`qP2!&OTeCse9cO-vvDK`pqqxV+wgHc2dv>TnV>88{+EEZ^+CaA z_x-Ut+Rx+n+crRVn7NU(MHnqNdOAX{U;&Bjzdu-H`~7yly(Skg`ac&B}Qo7PkFdVyYQ-U+VVw^6A6!agn) z&xNz=WV6}iH+cN2*zR%Su=Lk3_f6{}r-sR7nmYttx_du}H++w*X8kv}e7t^PLp%F~@>f1{psLf%7j`D#nJQ+vezSzi|+>)ZUE(0VLspwQEjGi&fj+SKs>rnUI1U@^^8ZiOeY6R}Ok$i1VBhYR|@s-%S5_tH=26EKbYv-fK zXSXv6q<4qz;|)`zz_!^mTuUkYE4^v!1DPmb3K3@UyJEl}3H(~A>uxoRPqkzZ+BIbl zEK6D|WPsH)Xf#_be(l}+;&RX!B-O1b4H}{XH5>CmJf3FrVVUrX=hEt}T?3uT4vEID z&!aBi`~U=r6zpLCZnk_hFB*5`7Hx*)qU7#M@Fuujztm;qXBc^y=*wT)AvXC~%ovB+ z`;z!IOe<-PqI5*h*h$2&`p1KR%ymP0O<0U#pLgI$I>jEw|EC)lcM)jbC3R+WeJr3AlEsQyy;@>t!ja zlRQkGHZk{GYxYan9MkVozPq@bD7N)-*oENmU~`x)==38g9k;nvw7_?C(cLf1qy|zhU*4WS+u9>!Qt%<7|Q7j&7SsrH?(V5-i;&I6U%ceYO96r(!*#*QDBE z!@%^LfO|UJXvHZPylkZYsUd)WK3(3L(uIKd!j@b%R0`j3rt#*MRpk3q&|VHm@wW!R z&F1MX3N6Re1$gh~=9YIqR8wa6LBsBd(Ru`L@u>ZT`xAux0Pv;u;jPEv%&w{lBB!9` z{6*6>?FX?$D~!?uay z+jAIUCc~Zn+R>8A)il9Ui_QFa_};j)TF=aizGw zhIHg23-FQe&Bw?cGXC11ixeCKYhtENxx@`LQ{#XVh-f)Id|%}>IfmW+(Jj~zUjhg_ ziZn2+1<9AZdc9vsATmIiBbK*Z|6&D}Vkw~FB~?>yT7-!>jt%w7o%wrRXxM$wt=L1D4xxN?WKG(21OH#sEqSk-^m~ zw(tw^n1=HmXC9gjz$@$H*(L)S zS+S&d`(opVl;BJse?gHJXDyKpp5NUNdyQhZB7T|;RW~)vh&D=m4V-Gk`A*9Xu=*S_ zu#xCZT=JTB{I1Ctf3Fc8jVtmITE#Y}dhHRqhYk(T0FRsdZoP-mkY$#iZRTody{q;; z+VLg$`zT7f!1B?<7{>2ORN~lclM;QTn>BCO{d7ZXHARhsAu9DbnLgsq4>re*&-4}q zOL-eSJ^GaIbvlmqW#}Gu#eu5sm-Z#g?l65d&wRL1oa3`x&Oa!@=e>Mpcr9?IcK~ay zM7WyxD3Tn?*ekvBkJS7Np$lHzwhUAKI%lM`XK2wLX5Uw7l>PT;chsF{S4Fq@myySt!>Cf)&H zeij5AtIPs7Yh(is#k=a3bCcJ=X%ilp4#VJE*(!%@y=P~V*&Kj-{J00yaw+6D|BeIT zYb*qkcwiuqMVLR&yo22m00%HEhVMik%w$?jGW!8v=NO=I=jLyGbttaufIi&cm}HrN zrb4taEqHtNu7Ve2q-HuWna7fopKkP%FHlCl%@eTppYv`v>?gu7EI^H6%G7|)t|HKd zGoT`yNChzZU15O`Q4B5QVyc0h`Ayr94S5k}NI-lG`0!YGF(uzV>Lf`hC+zjW!D3OWnG#>o&|oSWgS5}PuG_1%5R z{b+r>rUp;d5R)^0JL*0`f=-KrlkTZ9I&&+IedVcB!d3f*a)7)Wer#AfJ3`B~v8$zX zAA3$kq|DDwZV?|Gy$wS29PWCyw~{ciUA_}S3FrP1WIp?0ia}TuqVxdH?P(tS4_DPe z&9t;e+$#E=8h@3?wDU8+sq;C}$P;0n zaPpf@R_S*U!+qzY#Qd#4M$&u1q5cG0GaE2_s2tFL)x&uSovx**qiPFpS9`0Ut-LRp zu)CunfVof*n=JwRk-3~%TZM9gT!AfKDl1At$X`kEDWu#|w&$>bxI@%hj(YPMWy25|B$y@8gn_g%=_A&F*q??6KEEjH`c0KBT;|tO)>#VK z8W9yY2xDk+@AH+MkYw?_(HPr$TX)tM2SBx)2<0(o;hAfRojCi9I|6Rkmd8 zL%!ABMErsiQY+4bQ!&51MHK~zbmX*Phs_zX_T#uZ=K4P2J~d@QEoe6d5o%;V#4s2* z8F>7r>)dP{fX+woAe5$y98hGgio-+9LZrB*us3|$habGwC$Fl{Y34mj9rHx?XMVc% zu!r*_h-Ky*$5omB(U@;kcvHGy=qB!DyTa~3^iXj4UC4cf zM*Xg8`{C*U-$km)Zrk+N{^izFr{ zymvtiz6im60Mp+EI+efsvCGi^dcN~Q7zY|gMpIuI$8W#WEn@T%^!vFqOC=;)>!qeh z-YWBAK6=Ow!y?B-%7bRYvlU_Q3M8)aJR_4kXivhvPrTY`U|+0LkpMfxo1|8z$j`GI zRT(*HOg^-7Tc{7N;CW5)`&@{%+obJ&&)QN6e|?BmY7$9jFn^L{kfEuMhSQbDMu#6g z1MT{^RUIq!yY_(n#)`Rf8^hf@NzKJXaW%z4JP{&^;X8fnl%(C|*(fl(wq-iZ+*c$&Y1KO4re-$WR!P?e=~qp%0CG^8U@7x9(MUixT-38SG_vKthO zw2L#>@tz-Pt5|b>OfQ&WX7g4P;x&$(i%1}8e~@neHAm)!E}PC)gvFxV3ndt}yQ;xy zFJ1-ON95;}A@nXav6&3(jf|lPcQqxcy%#jb6H$c>c`RF(@|arveOTz3=fmrM7)p1E zx-IhN`h{` z{KyG!fKq(ReS@^>8!{Q1j7)J6*ujW=JcejwE5)UCt%4qpHKx92!YXi{Xqn1uO|fOr zz_OVLk^9}`d2~0Tl91VXnI_zaY)0|q^c;3?vXAEw1efCgMSK$o8iK)>Jx+y z1oz7_t@jb=3EVl!VWoU~?mqJy_Y4zu@6q#0%0BS9)wRMXk{2_kUH?$Z8Z2f|En=dH z3kPbm9_u|dc2lzUfysxc8psb-tqqT9G?p9WQ2x}|oVU4t5kcg{#n z>TH#`+`IHz0~#(dp}S$Lkl}nDmCx*(SO~|25h{aQH9``H%nyC}z}~yxKiK^a+&q6# ziS#;ji<9;iA;oQDomx_Z!AiTqMMHb0DBl5xhB`E zk|bH02qFqSKB9zvt`7s`_@{jEaBgkf9a*@7H{?fEUa(J}L8U*tK`AQza=UaJ@EkT6 z9$ABHSFaHU^@TGE_pi4rP;DL=jQ?m?*(Q;(1%$Fni&1XuzJ z$dhmexOL;qEW55)3bQmkjD!??(;F&Mc4F`o1O`s zm~G-Qy;s321nE4wz8RWv-}*9%9&3xDs?`;bxczl-4A_-!Hdmn7^SlyoMTr}n?`@s@ zxG5jVRcbmF^F|SY_Wk51-JZ!UsbAQFixoKE#_gc|><)G$HI5V8SmpD4(+69jw z!%ExP&>EhU9=0)Y*wR1;+d_Z(2t^QjC)_2UB_xY)M5BXij^^?dY=&jquJlYlca zMmozNY7>y5R~glmGtrVl0IHncAthUddsAg3o8+@xD*o7}zdlw~vvT2bD5CH=iQEEB z$GkuD`Rp%OVCdQhKwnO%--LluP1TLavqRAs==)J~M3ATZquvUKSM8FPSs_JtM7}Kr zMs$;Pb@;X5b*iBU<~EIToB!kx;yiN?>#J#01@L;ggIu`~)9+Scln86Osv3I{vFg*o zh@Vul+FZ2Hf07#$4Ov1vwt0k>nK2SN+-m*i7j(RlfzVMWV0Yz*Gu^w;rHpoZ8vg?0 z*ugcQy@O=6hCOs{@q7bYA3t1z!nf4llSVE&Ss~n~OYO3R3*rIUN)#s!V~A5)#vESv!UD zsNIA(4l!hUQaKUYW3If5SFYDu+9R>#WQh>zAX^zKWrN=N0a0F|juu@Rg!HO9kZQ`+ zPifzfTzQHq>XcNB;OV`ype_aqwHKDM*>yVO=X&sch3J=={lE{pRcSbLZTBUlSe#ZZ zl!&`rOqw}v<>Q!()Z7<>pR<8b!nvtP4FdI|8rB|N(kn5rCXlJ#ezYjn1U!q51dqYi`kxromlb}urWA%xK{xZZ0UI~T6tpl z)Ayw3h*SwEwk50VVU}&$9atLm4k(E<%zw6+6dje>BfjwKuplq(2Zfz@=aMWPWjWLp zS)$!i_{S0(r>&Z-X(yjXHS~&Lt-py}oaDg=kxprpUu}Q>BSk|>mUa9AaK#n(N{I6= zyFvA}36pMPYd2B4g^7}m80i5AXdL0R#UzFoCa!G4`9xAF+n;!e?xdg4JkGFPE}ipU z%mb4rCmB>?5c0+UL)%+NRT)M5qMMLXi4BNSZn{Avq}g;Yd0pO(KjS8{oGZ z%%nNJdy0Os;H?QV3}wquzF{_;C70k|t}`1-pKWCdyxL`XDut8di|QsS`{8M@sLai# zt&2jnXx-_!j(&{8jZzYsCAY?HsFow|zcj=blP`mzijvWR0b|b*Mzyj}iu@9!aF^tM z3G0HQx6SY&|C2G^E!2I1F`2F`>>?9-fYGs zK7?#A267LiR0{<=zpta}my3NKe0CTgz)MH76IM4`8%*DOBjL3w!NzwI{8o*OU6Keu zU*Y^h^HU!qTkm`a^XVTj{!)ldN zBqEI6r%w)aLId9@!Vb@YKQBf4CLjN+1Zd^oqJYiW#&qs4M6z_ZfD+0#RIeD`6eui8 zQ@@O=QXn^MxgX5AY=KrZ)B&F$41a7@8o7C2i6ksnX36xMl>couQl%vBeNA;En)M55 ztK}lS%Et&Y_F~n&Q2pn_yB0u*?0lV>ous(D85EXbeC7Eo{$wGr-Qfk;#69D?vXapl zn@x+}XeE*2IF6_p4&+ z(fw!S)=2sV#3KTG8MY@d549YkUuk|R zMhwarx-)6VfsY}K2BzyoMONLFpXdJ>$Gd~C4-NAkk|`%Vswt>9?E2tp2s@@;eKhPy zn*4d)n70UTH$v&DC46~l~FW6#kn78TJvBg5s2pBOrN=SxQ1UDNT_8{^Mj zCEH@i$HZjJzDFij$ZT1c6Ny>|<h5dzO!23pq`O>xHCWYCFYwNyQ`V1YCX_r#~zx zT*s7rp%jE5yOjBM`^)zzW%*#%LzryV-2t(2?~uD)h)b&OxQkbH`eHZDs416OdN;M2 zH+f~H)33le2A-G~9`KbPgN^jzGd78?w+)GrOJ*L{QYUH~Y*Drvef>9{vER45-rvsA z7Xlsf@uU|#X?{jrnD9+yi51e?@q6Q`vVK_W>%t`}l&bD3M!nvwGj7q;+**Au{uT$b)Cl!=LGvBQtyrWc#t)d57)|^$3D><^GMn_tC72HOJR+R z7YlE$fM0OCfPDJ38x{F#mGDXPatY0BwoveH7aUht<+Z#SPuSe?PvcAb7NiWnazQw7!-)SKV zn-q9HJR$2$Tg;W`OKQ@>n(`vYQa$3KuB33)LKIs5;&>aev^+Bbg6vIZ+5;GYg|v(k za-jzG3={uX(pW)u2-ER~Y;4C*sbE$ld)=BW$t@9-O+HFQAmQCLl5mdMXWUXSZ7qQB zuskyL!G>>}@8KiaP5R%Ql`TIo5gk-AwUb5YCH$luFg=I;4+yEaa#>8XeiRF{mx>f1 z_KZ9shDr*yS#r2;PAfqfC_6qbQSC~9_d2V(pwjJh-$&R>=G9=mUZeE31NY0`-Ywfb zp|?hxOc|%|$7U)Yy|_8Y0{b|Eh8Vp|nd&|&usxU$-H^r^LS;t*Ax-Pj;$0hK_!f5Q zK2JEu2!*AmYMuR}%@y~8Cz|}QUhFeZXAR1Fe%TA2zSkwk`|k#*Zg!J$D4{iDdm%;zN>m z^L=jHjODA<*i1dX1T}l1`YmJEpSbUQ(h3BrHJ3yZQjf?fF-tPis6{_gk#}_;(|_a| zaBpM{>*aTlufBf~R?c|$n~Umn+Dj8D-Z2oTXjgwp1B^oCkD8K3-!9xOd01 zCNlYFpuCysQWcanlOp6I*NVwT5sDP{!-bGUm{7d)yYBq=e3{!+@&hh27hChBipBfy zKy~0(OR@<8fVu1u3J?ZHx)z022doqMrxFO4RY}c3utkLTIGO)nu1qGoc3p^ZhVX96 ztG>wXQ~;22sleSw2C$$fGU#1)K)R~3KqF7!D65l4ncpyC`#6XpZM`e|E}a9c0w!3Z zv_V0-{u7`r8Kq(VS7IzaG*`FZY%pzyF0!r^i<=*`hJ5t_&&?Lu^x*Qc3Vw2v5VKn8 z>YP#dzVhf_IxN?Wr@D3cUoOmP<)#3U*c2eB|7_6Vx(Q{1pg!cIqeWR_-7I1nBV>_+ zC<~dMNYJnv!*O{HCA=|6O;a5@A3Sg;`0q)hyB!cn4JHdz_-KD?RV$FPMwp7Y`SCcd;S-a_II_cn9$_&du}CA@kDXUAdR@664l;&gzGK2@Q%7hEBbDrAJ@lYlsQ=? z(>va{hz)?Pz5nHV9l%m}|DUx0pDnsEQmA%_|JQWPqm;K|WZWqEC_~=DG&12w4wLUe zMSk~Mw-L;z0>J6SBLz{*f?j;Q3d{|WUTaN&;CG(IAEPec&jv^D1Er1oEdNU~hHe;R zg1UcW@)ftK5=<-Oe%PMFR^sU?05}50qXPdmF^eE0*U{YhcSsG8~W6C z`59rrysiF+pWY9Ii*k07T|tJ$c0Ncz;a@9-x;}I^yr^`_aW{~ERg>t|Ho!19b6D`C zfD0Xu)0Lp-s|q(U&d;tx@sOv-q??L$JE=;wr)w^FKmND0OvnGjw2VLss-2C>qyKYS zW-jDkTE=?rE9eH~KjMp+p>rJp?*!O(=7!mrlf=Ag2Ku5TWyJ2qWo*1r~{L!vF7?nFe{_ zFYhXMegNVp|5xh^IZ7u)-P=^`m3DSh4b#^Yjm5iGGlFK?gENu4wt9^iUHjevp8LTD z=H@Beh^%=CK`dWz$Xd&+aI@}!^96vCAI?+U9Ypu~bo z!sabmX4|ig?X4%E{zWi_rcNf2foTM^oS%YzbniI#IQI|%XR+m{i;ZGX5xE-p=T!OM z;Unuqu8OD(sO$SuZ@WiS^W&+ldspA4t02y+3b#}!sQ^($sF3x1?k+Mg5D=@O_>mrv zH1sJPQ^;xwe+F&J$&tPMRORbFyx-a2)CeqH!>KC>d_I@G^-T(X<(i30^5{eNwism# z)>%%%c?mqRQFM=a>&}+!k?P$d)PvuD{WXFvKWKJNP8xM%u%doN`QfcH8}NUVVfrqt zB+O+@dMbp=VQ7$pc|e7Y#8$mynjMLT_~58PZHU<|!OX}r;LT%*sp{N@%>ZC;S}2k5 z7k>5om)Prsx1=-U$nCg{3B$PtUQ+&f$0>>=w2aJQIB+u)P4=cexoaVrL=!>OMX9gx zF#P{ZhRJnahGw(c;T4SDA${>VO)AMx5Le!fAJTpA9@CsYNjQ*4iFmt`pV#S@aC>_3 zbx3XetU_OEj9v@rWve?`89rav_43zmJ6(HuAnoq1<1Y9}tY++|I`og=e{RIrNFPZ& z2fdxuPNK8-WVLua5-vjsx&YDqbDD?%nqVx#9FgC0Tup{%`4XCjl#vjR0`k%!lh`v7mmb7c^DDy&j@#BUI zMT`B<-u(c{$*yT^?2?P98j+Z(}PE(={$oqx{smNO&ymvA$?U2NioSL5?TDDg)5 zoW$Mkog(mkG1YNKmMIs_6ETf^G1Vx)J1qENx~8!8g}nWp*5ZkBo5^0eOlO^>G@AO9 z?QN7dT1${A=s>=93Unv45MPttw4%FMpyukvU}QMv;G2@gUhznE)0-thXL6oZK@!Y~ z)?a7+W}Oq3=x3etI>V3hnkP9^cX z@}olVsl_-q=vwx|W;<_;BL?NOq5n(n?vIVXt7P>rVJyZDcRPE!12P+`y%Ym+}a*4QYYcbo;y47`LoY6f9o6A+AR=uQWPhf%ia`2LI&vkpHgcQmSRYL##_i~pY%B70w z{;;;;93~h!%;*n655tHM`XLPW&TUJ+cXt(Gs!|P~BL$N3-;;ztp>&faTto>9utmbs z?V?YBe_NcBnQ`8Og^t-_#)~d)Rca$=et}vkpy&a7coE)MFtqfXL*rZbGqHHQ>Xu=_ zMgan%`vS%L8MlZ`KFW2i?%u87ZDiv5-@IJKh zKhYMPeIj(<-`xlE`7Rk{%}|=Kz*bS^pVpwaW%R445m|(Y(l7nexvY{hMaku^GgO*mqRnB$SK~CPr5VZBL~-Z zr>Wb@iIePZFcEULv79BUq;l6f+=#GK>=w#5#_wQQrwsNasV)&RBy)l;_IFDtxR)P_ zM)Uc=;`!{B-m078&N9G=J#*{5vX8@FFR<~*1c1^%HU4E>jJnme_wM2IEOi;r7jb_x z8!fc=-D!!O?%C@ZV&V$@c8QunqT7FozZZM`v5^a(#Ngd_j)7yuWNu z3_fX)>UZ)_iVvuX)H&|Tn$blSL~!;*=w(@G&cW45wX&kFkDp$p9nLG8KU_;`Je?Y> zOdtHhPBP$`R=s;%pWb+)H}6q;Iexf-EgXsLPD1PJi&{{amMt^4HW8~C3+*m>l zp=vw1$tM@M5{choZll7zNEN>x=!-vy41QSKOy%YUW>k4`qnz&CXWPfU=4MX6a@>*`DlD_9 zTBWkQvl_WMN_`WX9!#vV=-+y@vVRaBVrKSCnev0hkCm92j$Q{EAK7Gtl(}5QXr}}j zvqyHCo^p*v>^KoWiC(eF-`F9>OgjA)?jSjMt&phPP*A5r)-$Ki+pg5f=nwjzo>)Y) zHuq?MY^H9f-WjT7=2R2^D9QLVFQ12kN&V1|k%J`JkHU7G)(w0=i<$PqBcuP=;~0y>8W;Mg@>Yi}{DxV?3 zDemxgW2wxIMtU(N@dQuyR)b`9;V$@`(4vvd#gXCWMfY-6qL#X*Xwt`bOq|+Fxmm(= zu*>3vThs>1+Fjc}mxFBdNw}=w$SP&iFY;JO;lJn>gN( zs<9yc`c;V4sG3zJ114-U-ex?LL0w4!w$|A=;Ko1X7VdE{LXbQ^ckXr9P+7 z4^xk*EZq-_4<9nE{hrp&WE=jrUA3PONjD_l7eq8gy%VvJ?Ydd_IUK_^C`b62EDS&{ z$tA7l1`=Sa1Tzi|>;5cV4mN6kA1a8t+h%0Cf#+gJPVUJ>{Bj{iTs_CZVYS> zQWfgt_%4PkX%P?h>b`9)A|+GWLLn*8cSspe{^r{-hB_90`L+zpRrD6=;-S-Uf zs4H$wIP2n~U?E|s4v7C)723y4!Aj2+_Ln-Ijs;~&2t_I->te&K*@jvXBdu?f}; z?o^QBtTcG3!n?uuNk>UV<^i1{Q=7tbHSS2INL6{-Vsm-lY-(bfC*E_z>Op2!@J=FI zCQ)WDkcLs*V~h&_!?^d6y6SPEdY&xaV4Q5ZE_J@mkICSo=miD>xjnof8~ENYjnab* zbS0^RqD#+%+8n-pxu-%+K#CNWV=fS>u=a=>-fLUJVxNTR?UKDR za#Y9&E+G-q^z5I1%+d-lsYFu;Rc(x$X{m_?)2cg$hhxZ>se4UhSx1-1)^u-LO)Emi zkj#@pE!sXWtoHQ_COvm)Y6CMTQ|(4J(zc>to3yj2Ce#G?L>>qJe|dFb>!KE&XExUp(sCQhA=9HJ>%bLJ9AUF<0_XC>yQJ->0o} zy$THIj6Y)>R_ii3ZW_JRukR&#n%i{bh3zb0lxSpnc#!{6Q%XfVLsB-Vl2l{cag3|6 zdzIotoOk`(Eo}yWV4x-S&*lG)SAKr?0lNA2>M_i-=qvL|f6igtSITN0M@R{n(Ux_Z z=!F@O1ll-A$jRV)BU`W#6}S3K2v+~+9fog982`h{4vwdJON~sM$5#LCFZm^JX6u5- znc*vz--@sv(rHa#JgNASwEI!JwO0@~f72-VU>L(dO$`3Q@n~3(Jt>g_j+UkMDn!au zW0y7P?gK5DhGv|Nx(`f>fQg!clmYwJa1ujHU}hGqF&C@Vg1CqJbM?k#Cf~+-zTcxD zPWbRI25f@h;Hdkg{Nd;+q0Ge_brx=sO8nh1+oI5cMiuE9bv5kLlQ(r7t-~rN9uls< z$*&=)>-HD!F)~BYW7rgRM$M@yo_ui}&f_+*VNq~!DfdgOD&@87hDA;igQ4_ib&_*d z#y0i}tIT#b;p|HDK_?F;aiz9YU0L#2d$s|1ENw)Ax}B*PyvQ+zE2;5&Pi}VnL3)y( zZogEM=3$!J&(hP}gA7YnVXmdzzbj-x&%a~mt;sIc{teW-V0Ir)yVa;ems2R~pEQoq zSW0|c*oFhYcTB!15abKzY!vc_2OPhVs%ENCX-iHDR#`tEeZ%VMmiw~V#Y|I9ZJQ`D zZ|zxRyqn#hS#Tz8Q`o$Rp~K~1k!+EUgfq_m)H81=80eieAD(y)de3-R5Yvmnz}`b< zVYj~^WQat{%sCz@bhT6!=(K<2Efa`ap*U207RyrcMPHiuE74VMe--aHx2J7UyrXjt zl}9#@D`~PfZG&MWYL%f$QN1$S_XT4fj_AM+OD5`q?V~IrBR$;u1ve8s?_f(a8{rqf z)Q)7W(4PTxeIV{ynER5a+`UTPI?*9eo1ZP){qb^a(|02?Pj#b5*b>2-8ZY^<>kHYF z2(Bg0RZDxH2_&Wbz? z(6(@4w|1qROVqOg%F1s1DfoH9@Dtzypj^X$AJQu)UxOIXq^5_a8b|QB8f{9J_0llW z+nmkF2Ru^lOxFfkGns`I=WH^XS+1sIUS^(A7Fia|QpjPrOiQmv9;0B8y+8@8iN>Ca z{u0)itA}!X(Uc~GFSCa@A^}|&&v)aAURbx8S)fg;hj%$-L6Jr1f0}|c9HW?-cw zp#wE8TMUU#)I;y16sc9!PM|7T82)$DrHJ}8W=qBQ30pb%fk>D8iM89?+K1GIGo1}y z7eHZrB^D$-dZs_C*x%_K;RZfVJ!{2A;K#N>`PlO7$zv^q34t}p>BXcBb7gWxo zkrgN*bRDCqNCz!e0~AUEJ5{5&Q;nY3KdJ?CGji*#aObXLuW29Wg_jI-QuYtdaBq52 z5DeEI4kv}1tjh*Xc%HA2IOotD1SuOcb1*<(3ImzqbfGHaQ9UgE=?_+*3K>uU3q@M%z|V zRbPqDk_pHDPUeNW6pP0)^RPl}1!xsngfBs0TkHZ^fQw`|*&f?Y9$uWnZp4m9H=?Ef zltM;Lql8FHMuV6mHYwiz1mMIrymE@=*2VtN_=4q&irklD8LfJ5vy1lj!lO^}rv1rJ zoUgVh>4ZE1B0FwUoQ=-XJZRaPTLC9n0Q?clehsx?48CD$-YXERbLJ0-S(YaL_Si~Z zeU~Ccox263zc-kJT$Qk+co^tdmFG40m4p0ax#Xy`U)L}eR5F(iL#jn+Ot|Lt(1Z$5 z94Vs~4K5xDR{mu`+2N`3V$~@jcL85XCq$Ymly#qUlD(P)b(XOw_Zdrr{qAdJ9VDTb zLbNGYI;dJXSfAQ!tl;0i<{s<`R3NBiuB5Ex3yYX_nNLeLUwtRbBB%boQ3;wDbgJbE zbl`=6Tw>XeU8{cNE#QoBjy(Wndn?;oF^O3nj*W%-!k~wBJWu5zi_zCQ1=<~{h@KPB zeZ#kJ3-W`?k48E3K((Ymq(bZlFQ>;uWK1og{nzlY{k;zond-u3Ir0 z$_~7rQInENMjPQcbhw2DDs+9?hmtj*J&LbE%Y?O}hZoxosuQ*atze*CcO%N4Mb0Tt zId%5aMKhduD?cs?olj*bglLLZD7dG^2Mv!@BDQJiLD>4?+gpbSd`488CDax6&<`+l zaY$h#2h0l9utQ0CzH4b~Ynk8T9Rl+6`X@hnD#dg4uOP2WS2;ecus-^xG+N|0n*xdYm2A<7fbAt>M=D?_H` zs?-)OKOzeH7TE=q+G6~8wcd>Gs{$vjT66OXq|?eltK+m(q%i0#z7FI*yoh#TaIghT zd-P(e1WzxTMaYR<^F%M23wI%cVr5;;v6O?9A^BIN}lcj7GMVTO4Jm6Xdg3CVtrHo7}A^vOMnDa-qAUQrAHq%2=p7zR0R zd`>)fw`>u>@(X|^*Zn?AN)(5owo7gd&9d5h!s*AhBddnV_lm7J13Nh~9qRal zI6#8%LOBP@gg*tWJIZ5ZAgfKdNko-PlU`}o|2pzDe9AElopFVe#jvi7OeSkt0#7(* zyf$l$=x+XU%7sBGs6K5Cinypn7nu4iWS7)M7jO6+l%mNsh(3-|T1w7Z=T6%`2ko4* zZ{cWkb8_1;Hz@aDpKqb=#F_%RoT&SZ4afaq;R|4A zOiMLF-JvUPqx(8i$s1Uo{dlpXjH2s4SW4=&XoD^td-dCIm zsn1#;d2%lO@b#d`Wl_ju&CwLoymx1S#p=sn;Y;HF1U)KDQFh!(be}SO7qd-_%cI_h z-yT0*O>9FKxdwTe+7}_QMI(q0<~?I=k{R}`gxh}wI=*qgT*dL>M#W)Pm48@@pJt?L zw+%4OrT_&B;uZ?lSBZGB!Y|i}?p_lH9C%&Qq|5enx6? zc{wU}!~{-0w|TRKAy6GL*=u46x-!CpXTr)bn_ zMVmsPXoMcuW-X05riTa1FGMjiE-Kc@%WeZi+1}w?dNqxQ&x7?f9DjmJ=T!W{dKu5M z4gH`LI97*Iq&3c>Hl-kT5VLe9Dj}SzL)@#^6WjTL7JYwIWj)r-Nr{NKg&K^m<5<_E zJ7P|7_vZh66_D=2pDyL*t5iniXvWVMF7EYeB>xjnw~P#dg+NZ7aBa%!EQ0DKY1#+;t`?gMy9xEJ+3x)5EOdn1_mL4vA6~Os|!wBmL-uJOsWJ zWT9+6Yfi8pjzfH+`6%$1iBpj}CET=74u8zIG_a~qk{A1BVO?Sq6GM`p$cnB8s!I8= zkQTleYEoO~aV8at;!>I{oD#DWS;z&p8DXM3JAeQ#WouT<4)6iEa5*F!x3OKWhPS5`s(AX_G610p&`b;_x41*hGsG!hDz^}A z3(upi_%zr**>fynyAD`F4Uv9eOyU1^=M~78B`+z^@`Db5- z!CUUvbp4w=mgfl3bqni@?Bs(;Num9?!2IQ00r4*YJ!7WXs~HfpO&u33S0*KIN5KlL zJ7ixds1iNHjJB|bWXKj}6b2$nvZX%%gLa5VT zO!pD`qeg(bmlJbk6L*fa{61;{bbDz9fy7Kr1TykDz=9UckMIc>JlUI{mx^#Du!Z}_ z#48SK3RrChSzRRhk+anW4;>uDeiCITJx`D(oO&E=8a=Q(WO;PHE5IlSg#Dbw1gv- zp2*IQ=)&`el=2uwv@kt3KE*2%D|ToQi^QrZzfos@5=M7PT)xVqlA~do7|hN%m^A6~ zA_Y)yv12ADQDYO+<;%j3cUE*kNI`=N`{!qJm$dF&nv&rB=piQU3WH$>qw9kD&_^sc z8YN4Q;^EKI@6$B2pX8}Bv_|wFds7O@5M33Z9wAkgJIXbk&Ukef{NXnViewFXunpXk zXe#H=%ZppOY#wiiA)0n#V%?^9^@10AwRTT+Of;5~+B3Vij@{zfpQa!`i54Cd%1&^q6$6dt#rsb4diVip3 zv@iqaY}O85I+oYYpw*=IE8%setCHf^7jZz|2@JzZ&tpVXL0Io+_KPfdkx-;O0k|l} zJj#1zPs{gRTkDCImUd%FcF_$#S=z>m2O&p!w6SiEECGZN!6ioq?|*u=Ws#?w|Egc@ zikZk@DZ4~X<0yT#K$f>u3S8_w5Q}Oe$KhocTRR_l{%MdS>P4T118wPn;jec`^=rgu z*cU?O522e;0@nJk$dmW-?$K%*3sTY@Q*G^m_2Mn1APfa(tk@tk4U$Z|pBsQ^?9NT! z=+2MY-b%;}zPn!cY2xu0Bi~5smZQ5e(mSBp;q`eJE!IX|F7@iy)gT*#TD!MeoNsAg z16T>m52%i*?My)?7adQZCmO}W=uwoGqR^RFqBRt}47o=rgKUB~4{(^*c~)FU2e~ml zEV@O4>;#smeH2``#$xi6YUGe(7-IZOF(k^o04ZboY2Nyiey!#8{M>eS9$V&^F6Crg zVt!ThlwGx3+SrcYS8}_>d4T))i?)KvN(nVdQj{Vjwk-djR_DqVHs zo+JP8sE|VUQY<854n4+6?Af~M7!h_!O3uY>_P+o?8CD8lA-u*y4&4bC{p$j`ND7$} zH$Dm2em>L&(z|iL>C3xV%Jy(z8rjFlFM!X)crm63-Wnlxc7IA z#vy%nbH>a|%d&iG%p;PaurNak^jEZH2_GI(oCzLb7#5YfG-y5tFmI~9A!fj#B>z78^k9N-trOY z+pflVK0o-?6KE6Q+pz^r9xjRGFNno6IM0d=5yTz==jhd%y@PQhSLkm`DjDEz zRO7gzKH-oi@TwYcP_(~9NY>OD=MX}x#!v6v9rwa<7R$#Fq!i+|x*y>Gmv*$r zw~QAX>}*#KI;E}xpf*rY`yBLHY8zsRnTFh!7VrUcjd%f$e}3t%F#PVfkRuSDf&vNc zMij_D*keU{qktvC9Kx3eNClm?F!1Wk9GaU~gQ3UK*vz8)s6kz(Tvxhd?lh@G=KN^eFHXHm1o3ME^AiAhGq~>|(u2GjftJ4sADR z|3o-~x5zo8U?35_U=LP=2H*ZK33+>gJ#;?tL>F2##69rE@}XrTtY7S4sIx2m|NI6> zoCqp(;G@3;To4RVY*daL&?}J#AV^L3QJ6*o(EvSIA$041JQ(lh&QJO7x7a1%`|_}$ z_>VZKGI-%i4Ed=T;uQKM@T+iHXT$%~m)NhRI!^8w)Vf$VclyK9S7mqpW3_$o2~>}!RYb2$Avm*&?i0Wfstc~Z zwwEV65X&XR-)VF9dyy)uQxlqEj))JqRq>xI%6UZf;iCpY+s&mj^glNX481}T0T^X@ zEk|XZjJ&SilAMocN?>k{=6(F>&_!`Wbi33MQYmzg1Vkf%F}}wB zC24wS7CU7gd~F{70UwJJEY`yLoe(N~eN<4KTAj#cJ`Ex@+tDf_<)D924+Key0UUJZ z%eHo49>)+FU-^f5C0!4GIdD*To-H~7bS3>(-&@7H<46|v(3F}yt_N!!+aLNF1{UJT1JGLN%tEl9q=7~+AMd5X*YX%eAX<8dJ5oq zAR=Hr5Dk&&{SYJp1FkP3JO6#7`>8-A+q0&3Y7AFxw>jzKv_m9ss3OqBk6R z1-`5l_+7!e)@n>4c#`^$%D&sEOh_bw%_oKU8ek4}0WhA1v9uxz5MxUw_AgD-WIl&1 z02pO`0x*Us-0(~O+4L4@dZC{m5uy5bcS?X|JBw=p^syG8gB9W1UrX8c2n?10{WWvN z#_b?biT@=F``^t;mI4~UvVvLE1vG&9F0lGOv>DWx02Azu*3rLBZf}8F`t=}KDdABo z&>N@qkb2Y`Jwug1+>2y;6y2_hm4gMAl*oa0)Eiky0L)Cl&V2Ly#qn0WmBtGZXtTh@ zjX_&&$$Bf`5hffFO})_!#K!?JO_F|>--nb4 z$&9a>nPBqGzy7Pu|7qx|WyWF#jc(P%jH-x;lNW$Z1NaE^k6z6QL4?9Kpsc-Kn0xu3 z+l30O3>`lEQwhWXc>B_wm|gJYGkiWk&tVJD&f)-$;F+T=2Q+;3lokL^W(S`Jn>=QT z>!lm?pKXw#H~ec>(+Uo%dp6f5!szNzEf6=f47r0J1)n{FZmlZ^IyGp0TasTjSwh?* zA8wP-XSrY8ST-M80&uv%UHm^Fae*XOV5GG`Yl9df7*s&Ew*7?}I#{+3Sgd*#?QSRA zmVgyWIa805_{PQ|O_v{dBkiGFV3&?Aj*lvM)a}@YL{{LBV50WMbJPx`3W8nQyEH`h z^xs~12EKgWK{$!xBLRY}g4{Ha_Cz7UB(R4bz`-?|+Xn#t=ruU)*nk!hpEOLMN|55U z0_nI=Bk;+AuC(qG$B82}!T*S_*wF(&sokdqMAwQzTSEIid-bs;b?4sw1)}wqP!N!xQW1)zAhoeV0j%B0?oVqF;T;r< zZw3%;H_(6(E#F7GDlrW;KQK)5msll`X8=Qa2*h<5$@S3&=HZQ#o*rwA@oqxVGsIlSi z4h0Zl^B_)Y4)DMP*4KdLt!aIaIRhmD6e}N)-E5iHvOwAa3w3u<<-lv>(AV!^qhS5< zLIq|9K-Se<(Im0aL;`mTz|rM~bY?($ipAvv1T?Ls^YkTP6GxKzLjg^@rgO7V5_qK2 zL;*rPRp4z~XlOk&CM;I44_m(Aev^P+kHxUfh7_+YDd4i$Lg&w#{PRo9 z28baFxNvuA^8v%t55U*Hrg#SM$)MOWz{^?Xjs^itOW+!1-WPwI(w+b<>XIaQrZQp! z1GOGf8WIE-I&dNQQKTw2u+{;6`v|)tM+HGgrv^H(jSI%8fI*`IqRsl^xUpn(Hh=xs zT&qx%fBk;J-%|28;BlirgjNCz9g6)SDFUW4X~8y!-jy z7L0j?qX>-oj~Mb6F0)R_cGgrl|LXlyXV8eW166xtK#bsk`p1+x5c%(FumGG+sJ=ny z9Z;pVwKryWP@chMANJp40)t#th&Leum0xJ6z{I|JrhWwQYoyShIdl;3zN!YH&N;ir zV{Pl{sscRL4zLS>XuFw!Xj7t=MX-K}fKRjL1A##CFwjS2DqX3`)NF*fZ-_F=X?NBV%0>X=kzkbf&-j!9J-;>MrTdT_Y?Q@=CzT((OR*)jNslGL zWZ&&P4h55Vhs*%`Dh}WaZ2{VF!Pm{!6YLED5pWx{1Xct0_p(J?5xf)Z8eOopJH$V_ zQ$;}T*6a8u(9JRtKQPpc1rTsSrSh6YD?yOz+hd_Z6iUSGhLgN@vld{ZSAq^~ec&r8 ziT=mXenUf>59kAD84%pEwHb~N690GB6i|(qMkSdbfsE1S6bf0e2v*eK=QFjh(!0t$r4w+%Smal3L7=^%;Y zLr1PpbmH^xkWG7_^Eh%FzizLfW2!#I8aOSC=oq!1v6Ar#jCbxtRwi1BW0yJ66LF+U zAXdOqeM(VCRK0*HHK03-g_{m{o$Rm7IOE7eo`{HumVb2+V4&*E=f)6P^7#V`&128^pLJt-1>d{m7 zHYCCKajRX zI&uvKwVcp6`@Z0}jN_K+2izY`N*_W$7-pG7zey$@)_RrUqDGeh&?3$0T`OQca z4t|iWQ-+QD7dQ;7udajBsQGrDa^r#?94|T=)s2pB-#*#wK*)sFQ|#3WqQM%!~}u< zP6pf?o!MIvU_JNmAQ50yq|z|4+LeK_Y~Oi5j%IB>B;?=dNeg};Q)*0z{znP?VE%8b z09bQGRM;BT{~KlT$%m%TccYVKvws1t-sFB)`hC`H31c?04tzfTpG10s?^oW8D|=mp zRF9^6wlZU#B?9eCZ+vfy&zAtSKt&SLg3$Dtt$F*fpW|f(Kh;Ye#*}McUAD~`G>==2 zl|&c&REycTc7>hnW0Kik<%ybZ8L=x1fz=E$->xOl6vj! zc}!gpVu8XP%uFQ;Oweks4a_~=t;CraFJtM|slj-nU3KAHQ&#>h4LB4h`k!l3e-4sO z6q`NsEbl1%@udUfW0Fln)0Cx0Lv7^`mN)>n`^+l!%V}*^NAOHle0r^|foN@M^R!*_ z^?rlL?>V>2qiDSCr0+j$#}q2Jufn!$C%HaO&3um2D_{Cl)Zu}rU+`FG=C9~=Ri?^m zWqjI0GlsvHrq=d43f5yeVadN(qK&{<(rbj9|5`m6x3sXG7B)S{6=d~9-5QWuB5u_3(k?!NXXO*8 zvgAyeUG>%0GZ{V;MVYisUW?}w^DfCxos>OU9E}Lv>-M;KomYvRuckjXmB)95m95F# zsZb*{HyBn3Hb^Da(f_-wQnqRk2VH?y~XqQYr82@~YKXb4l z!)rH&Z}4}Hw>SAoX|aCOtzvUSMvO}wx=n@%jY*yQxZMsg1EvaW_CvvB+`nRsXdCz?mwHE#RJJsW&CxF*}J!JRzhL>8cYa&7(xnmpQU{yCct_IZyg;c9}2ryvX(Tf{=X|xx1l0!G|t4C&y*( zua}e!HQTpNcH`2w0dHmxWeWbYez3#?<~x09gJ5c1EUY?fraONCik&Se#Tf?>@@2-) z)4E@+j5&NZ$cg5Ae@)cO@Kku!+g6pH(#VML!aABHv8lI?kk4|rIdYV zk?JE*iE5eVI?k^qRr4~0L^t|^$EKrmhrU%|r&2}4QRd0pZLXy5n&ZR46#SF+jB@Se z&MznD=E=oQE4``RPR@;hbz+rmyw25NVt%Xg+1o4Ok>cW~4li$6k7hPpp6>!`@TUz; zH45x(CL!|)FP=WMSf!rgPPhI%Z=C8bvPivdp+^%D9rk?OZI!thoJxIZ=DPo;X>Ps~ zQwI_LP6OS7IC%lkZ1;anHlpP(2!c!k@E>DvVM6x?&KPv!>_Z3LrQ-F)7DqDIi;5p# z>YF?c7IT}WszTU%qPlAiKiet~>DQQi-kH0~j%9L7dh2=S&`^2qQS+{cY&>iDF0Xr! z`$^~Q1`x;TcvC~qsN(EUd+siHrFW=q(`tn_$xVs~v z>qEsKt8HdFv~aI%4rNZ}sW|E;G(D;`k~?c0uh!U|4SEIb90@F}oU7~Rjoe>m#Wv&y zzjZgw;Av>g{Vc-FsYGCl=eTPN$^(a2q715SA{%6gaIKlSn|BJo14vRvI-bNMi2qG;Y64D(i4H5>@ zARt{z#|$Z*A|W8DfV6bizzhwd(hNv5q_lJo4R?=m&i#G&yZ8TlpL_oBIgT?s)?Rz> zwby#zwLXd5@}wt{!2+pitH9e~9Pmxa+YFbl11}U4u%(~$XWLB!Q|IGJiE5xYC}th; zmV^E<9$Yr+D!ea1{&M3u{A263c$Y7@7cMme==E`%6ZD32to+c;d%2IU1>ZR@tvqMy zSsz6Qo}HgaaK5r7*5ZFneg~X=+cUSa47J6aruB+aYI+~^F`8GpFqAMSik>~DbHkQ6 zKQzGUa=m%dmZeIEefx4bbfI|im_TvOR_8aBg7 zkE{bGvjks6;PHTk_En6V$9=~XgD?P9rjciu0FnP=@ahn5{L(LsG2&!Cz9Q2FUlS(B4uOIWX^2x@XFak{V+)rj{vFfI9?#}drmtXIGa%*>E=X7m<+A^H&6+X zyfWYJxz+c-Sb!J2*pE~|^Lvp}!k2;Rnme7~>;+7KmUs&y4`5RJ$CveOVpqCsOx+y@ z4?j~-{XBab*&+dM%GcP|BPCD_IO)BzXh)Cufqva=C}$CPU%j=#HB}xVlI6#&2x*S5 z5g7MRg{l|N|A4+7?#9FU2#A&$_JBQ$cgA?>5)SnfDUCefuEqm?Av$Rwr|h?k2~W{D zWvbn@yd^S6Xj8so|+>l^E zT|7Nv(PVufSiEZM?cEZr`Ez2og5Nx2t-JpRF@U6#!?Wv7Vf|iR&RugSBx2At{ousUdh3!BJI8(&G!yX2G zjVpbX5}qo|TBRgvS^sWLU0{)Fg};(u@_4lA-1&*xI^MvK-~t}`2*h1+{BE*e9v8K6 za7Y!ne>gG?!OW1fl>k0va8bhX9K9khXuINHDPri0%tkM>-Wg>TyXfw(ZwO2EDhkH; zT)dYg>Y8of;l2!+OhUEv9XFJspn|@`oEyVvgFA4|16O4K#}F8w!W$X=y&vALs6F<% z6p(ou#jx79Onb1R2TZa_+V=xk^ z8mjLUEYgEs?Ufj5JT(sE9xod@chbEqM5AI_chPK!?#yy<>CDg788N)F1|!iDp0Sw7 zYZi1GQ$W~FOK=)(i&7MqnJ=h$?xkPTH)1Yfvxs~kF=$A=Tf}a>fw4pPoK(h=vo}%4 zaL)B*R*x|XJqdUiFZWBhZOP^nX1E-0MCGX8m)Wx;jH= zPEQSHi{2g@qp{EFykmZiK*fa?wHv<@R5&hh>aAX@6+yrX?IO_#(;59b@EdS6h+{6T zYatL|LL}3;PHbb&@EoM`0(r0tzXiXK~f z7Fo#XU7$HQ-OwVMjcNjQOXnpnqcg!L$Khe^i#7*L->T_~Q*bXwwJ5_Q8XBw^ zgS(RfvAy@pzYLCB3|i=;&)1b(xY%ZydH5=^_odegu8-)}*c6QV-dVR6KOsZ}xt`I; zEdc0*t1|Di@6gIurxw8U5j0uBR>vOR7aD|%d-^yVt1SiJsS4v^w{57MvPknqhma1{ zir@b%I_Vh?o!D}@J>;C!RAIxzJCDsc*))%I-)<)C{hd9k<7%p{G-W!L!PbXNO;1U0S4bbEJP_A{9KMb2xR?qYJhbAE$b7 zO@dKjnZeZ}si$e`>vM#HqT9B>21m??*9;$T=%pBtG`aY|k;tWlXkNX*&Y|K%ch)SV z%IvJOfGzC2;m#{6^~p>exZa8Y;=K1Af|YjJI5UXAUJ{i3&eYEe)&&f`ClwegJY#9q z$q@eftah9^KCTSB^@O9L6F=z%%QhePG-gJMw4w^ao#2U8rzVDHp5i}G)wjo1*LN?x zSG*X-Oe53uz^(6txHAq@p<@GY5exEZF+gnFEHbd#w~2SP2_2>`7CQd&a-%;SJuz?w z`R?l}cI7$+`^vBbT;Pt+dgIM%VD$1FubtxYxMU;_udDJR9s8O?Ni+u%7iqLcqP?<- zE!CQyj%AIxtXM52YkdoW)FPhbJ0e^+?SQdHX%u_)-N`_`XdilLcDjv91Sdu3w-! zazPn-%jHPUDcy!5Qq-$=6)g=ZYgOV+wMhNwuuv3g=_VpIMXxUJ9Ys$42E$}rqew4} zz>vABDv{CVgl%s`iuvQTm4ir=JSDrPDW)Q%Ze{auv5u{{SXH^;dRp+Om% zPK>wyP_1shsA)GDdblPNzgx6GBUFo@E|`qB)MQR-bZbK}jKaVx!BYtsm8jI)g3SdZ zwL^=&J}nZ{FqCc4!1v)2-NgVYsDx2!%^>Ef8n&iaiB|7=FH{=a_~WLKEdX}qRbKY4 zMSU+HfD9M%iqWelX+l$(VH$NS##QdIm)z!6W%Q8O<3-YU8hU-tujLXyPAxNuMB0r! zFEO8cieJ7Y^WZ(XPE25%h*dns1)=jHq7MiKt#eA%`e{ieLB!*S}zvH(F+x1#m|@l5>{&+V)U#PTRiW z48H5BG(!>c*?2Mrt9#5>*VC{aVJ`BJr2A%}0#^OlSr`{DMmr9HGbkxC<&9McBhiQO z4AzXHy3Cpsmt#tFH)#X)%@9dZh6D-Lq;Cm6gsx>0SAuHTaB-DLTm~yqD@UWB6r~Xu zKf%r&Uw`%t#Wv=)ND*O@hYU%AVWDN?mHQ*UOr6TSQ{Hu}Hk?{(H-s9W2++LmR~;Bo zmK=zy-<3RBZrdb`Fg@8-sr%W@XuzcB5xpKjda=&H{;GQ=dpXP?EmTl_Fkw-^d~~3v zB(mxFD**YtN9$&_FYXt5lIbvPKO;6beBX}V+ltq}X=Iw=sKWN{CMwm-w%#P`(aixv zpQ+nIO}(1+`#Ix#jTJ{?oDS!5loFrAMxYoU$OGATA3m`+xi|KQPi+@}Yrr$!6k(-m z8I{QnHYLI797|AqdRwfa-NhoMz!oYyT^S*}7T=z;zu&nNJ6U@oTXm8`$Lg?-rsAyq zsVRVK3>}j+Me4d%a`LkO$bdA!wXLpxFqCrt0lXs{PY~n>5UbhYQ;-`K z;0pa}=;8BG%)hjKBE7e1wo$*slH1x|tMJBpbzRF9TVzk3RWYZ1t4@bYAa9$+`d+r) zUGRv&tQ-+0=;zQZ%leR!C?I^6^-jYk!^nMWLuFuE?CwYKl=EUb18%eG$LcSg!hiD5DNp5#U@E+A6Uck30 zx}p>xduhOvBsSH3lXi=LIMl$f6BEh=?<=@dkJvQJJMW$lh~X2UVL7LX)82Jm(!(5S zIgS-p27MIV(4-Mp;`Y$Ep+2u-SKWh~4MlK020hbsN)^+v?v5PGT628s8UfdDhY6}9 z3}n)8&1H(^z~Bq-jLC5>SUaQ6O`z^V8N|5JpbN1xg933TYt zJf2ndkybPwuZxqLbAQ-|Yz&EAmmu%$$5fBWZa?Esq;QC9l!QLl5SDVEy-qIL8$`IW zd91&GuC*asAowx#huDRGhBD6cFc#RdY+lA=CTeS^0%W19{@^&#DqH@y2Me+4fPH$xZ=W8q|tOy3&4Z6fqf7AMC7#XYqSVX)ma38z!waxi%IjYL-aR z(-HZsc%jmhW+%}K8ojJ`O^jji3fS;gMpJ3e6`R9uj(RsCsNopCZqv5n5t z#p-?B4Lz(#4mU%C@`A9pmam-54(Jp=Hntxhth|b9 Lbb(yEhe*VZu`m6KG_QyiE zT$4P*7^w7%@Z_yT(sPCr8h*zZIMlG0j9nYO5zq5TdHp@1kWR6D%fNMhHj z5HmwlJr6NwQ%z}#sp;yeeddR%8jr}}=?@$m>JZMx2?2rzg?X^%OqROLw9t3#A+;FBR~Eu1qjhL@IqVlvJP9>g_lCqs za*`oUr}g8~Hy9 zn$K}F`V5_Pwp59A?dalQ`atBRJnq3Rjf-|rZz~#VG@1`G$bTh}JpD962w-W`YZEgh zXRF#~rj?I0kCh}pZbY-~Z@xe0xl4q)zp}YniU4%rrcoV>w0QHn2yy7e4qjc3yff<$ z_5vm2q2z3;$@0N>iHoHUDj-7QThH~-X*@n{oYEroEb_==DLm`mJQc2E>i4{pfnJ3X z%lDkBynu0L{Okdlq?oz`2pOb2K42p|Fotf4dckFDqn0-tkozruIQkNMs| z%eIu1{-k8Ig?>1zxq^J)&S#wY@d3Kyb^@GFXP#0h7DP>RVuTX#epwV4JOT($5dR8K z2uzkRR;&(lRc*R>n=IYrnT>eZ!U$s&uxgRu6bvoryP-??J9u;)4Z@Xo8vK=htqcYn zBxM>#DC2>93j~lZJKb75#)Q<2YMIlYmj4Rx1_yw!fGjnV|2J^(1Ms~(0%1Mm(`buJ z{||A>U%r^Pguv-@XrC4DM+}@*cY|M+@Pc3%xL!%DZ`xQq+U|Fk1ylOpY<0*@cMxpN9KR?VLYgC+Wgaam40>nHwwa)Eznr@;c1b;>Yx z73)JY`kpc;a=V^3v;8||DNN*@;P1dNDMCz{Ly<`&eUQseJd1zFN2Cpu<+U?S(BGL6 zJ;l-qH{Km3LXx-U28jIqIO+!YUdN}NU!bpFjnV~U!$?~e@gMT*5xjKz4S2_hw*4O< z0?6lCwO#!b;7{PeoKi2kTAH*y_RYIl`dG!Xa;Er@Ppqw za0X?y$3f0&5+})5K%9Sd!2T}qb*JL|`*$KDEPz?>9qh;OPyy!penSrM`Cm@pHvcQI zORryGP?aFy@H?>qMix()G3ZP&`yZDA{+0LtKvDjgAQNVvG352L{__O#?U%*1Djn*jsq3d0GZI2m1d{YAV-)7DMGhva}_XYmn4~W|Dfp)1g z=brvIGd}RiA3n$L`ELPT74$v4%=vfBIIQ3+_QxqY|1P>ert9zdinxDd`uaam@!v%= z1*;Rziv6$9nSe^`?Tqr>|932JIYG14vsrxqjwOm4%n|jarR8rUQlFbZPp7L?T)$v? zV0?7@AZ(J|7?}XlD*RqAdJrWhCz0V(`rC{bFl&+Vw#{sRKcI{9#6oL~ z`hd^x9c>}^~%Xd%D)UoD{+w?Hb7F;HyU_T*r*svZ<{+%zu<`hApO=I|Eai1fN~mj8|opn20i1N?dM zlMoUNZkO}fC#fC$F94?U=Nm-#F}*qo5B%lz&`=5rrKok=B3+wl{C!!Q7a9b7uh>oE5@vAyek{AMLiJBloC&tq@RHQ@t>aWg$ zpVVMrz420fH~t;_TcOL+WyoZa0SxM#WE{A+Os)}7mw_79dVm33tdb`2q_Txd^z2u7 z39|_Q7iGDG!!|tz$^bymPmh$f_3B)TQ@nOM1VO!Jzy=1~&Eo@*N!boG=5#O!@*{xM z9;OsVUgar(gMop2Uh2L@`k!RVfFO^$jgA9#nuJTW%fA5fBk4HYCccm)$1WG3E`F;) zK#~-Iq=2~DsWO+%a8ZJQ?ZhD+T{J_|0alYznxI`*Pq!kR=8?L+({pqd~q_X(|*2 zV-C^03wRV!q*(R7PQNJ{PcYC`{+Gu9vhV-s2Ply3xBF>01w?uca|AQI67wZYBOE1dY@!oqDzaxrSh-%Dv zM%B6I`aD~nY_&f-46_wagq*H{Av^~|cpsY9{0m_F_jo#00)XMGn8NaU)%NtVm>oNh zxoY(vpY^`JPR?ZsXrTzKtJj(Ot_Oi5!|@RqigKO@dE7DXElG&7LW?Duea83N(+oF9ZMpver=*Z99Oi@PzXg*rDyXmq+wAbN^xirkX_3TLHfn zr1AnkNW6OnZq{EOd{KWDp183A^Zh!$D)|r4?veaBqhge@1uI7K$Pdm=uZ`j9;jmpE{^lo2P(e?G zc^=DwI(Nf${yiDLptNy60pB?yjpJ5e+Ty~~5e@yooqhmiGke@Mq zv9QGXCJSWx?oth~$7FFX59b8D$W{m9`-H7cdSykpJ4nh8BbPpRu4B1?-tRH14$D~h%!7C_8>K(S*MJrC!I8dqpq8Ctk98FPhDPYO_3EcY>@fCuD4T2GR zBq?qGzu^DhE_i-9?BQ3Q-+e~wMdUL1ARI)uE}CVA^L_^A-z8g&Vgm1A-$sMJ%@C&W zp0bBb7@aEJ4zvHi+YnYUINhb56&h>9;-EFaySj9_qUG#Z8_r(pkAcngq zcZxP|fA=64w`D0LOYH2_J(@?Kd$G+UbxaR#R8RJ@n0Fv#EZ~)T3RTE{xIXsl=k#}> z>vXHCHf{xqtRqX!wb8#Q)((tBh<(CnZrex#8gRplUdTt< zBJC@zm*b|b<3<4A#Uk%Z)Gm(oQttwAPc~7;Okks$fb--5C5jpR*uYpnjRc$*YG1+? z8M3gyAUfVvB@v@$l|G!)MR|h1pVPw@rYc%w!Bx&c)hp$NF`R-K8?fJjjrW;4sD|h* z9T7swygnteo!l;C?;2qJ{~qE&D)&=hC4!sapIJ=oLV|jk`33kZ*5cjnK=9*G1=HLOZU#KL0$S6hTg8d)yVv2@)?b*5(6xJ{;cHWzhR`>sBB|G-NT^q3%*VdBxdSu(L?2k8k#|eRl;H=SwyzIrRL}#p>gH7Ml z8LwzwyQ!k-_E%N`f5R4`qMp+(T)0< ze-1X=-FduU7U8>{UtN{3EJM26=4tRaNlz%97`G85)wq?%avDy^lXF|2$wZbL_7cD; z-(1sy7j`aK*3(RTnJj5GNvt$@P8piG6K^?w6FgJG#Qts-C_kcyS<#=FkBlHrmL^>` zM3!_K>yI8Sn6jLxVC3w^fxf7jZX7FDZ^B0wZ{=+>+1)SWeuD4P~WrtcxRs- zE3vu3i_t5t{GBoFf@SL0sLnbIH^!{+$VYpdHp-I%|0raj(^(*ywoSqOt z(gPIj&x-)666585LAygIko4e!=A?fA)`M^6;$q)Q8eotKq^HRON8!+-;p+6p<+h7` zGEPWysD{e{Qb?z7w+GqhCb>2RHAwSmovhn9BZSN?B3;b3yVm?41twRc*B(!DsiZy- zSxHJF09RHkX|w?rJ1@-pWRl0m-B;by^GvsG@uNuuJjrI#cv4gXV2r#KUEEvtk4F}J z$~?soon^h(wKOUn7+G4KABdh@098M|;D>0}HIG&Od|~9f59dRCcMn_}4yw|9KkRB@ zUB6+YT^9|vp6=Z6+?t|RdUX;%4xO10DmG0ZaoIMbuqce&iPtv?zT;EfP~(%{AhGIf zGg5T+@ZjQP-dFFabD?-O(*Q8EBD*xF^-qO z)9fl9*}G{eou5`4@7#yW&Dw61))(|^2pZ3BaqWF!_F)BkmUl=RaRHsSr5I}L<7wJ@ zq<7Tsc5sOFwnC@^lV$!0fVF95%J>a5<2AI_`0=)tYgiz`Ds{mKW@2s_`yW|p25o37#FD!D5ctA@2X z+mi_MxoOC^72n+@QK|tMST+UfmC90XDk7F@$~ii)cy**ovuemWV+bTjL-eRfm??}j z%t7jbfm}BiG*;64@_p{;M=eA6rt5|2g2IO~&8kWB+%!j+z}iF$Zt6;*Ora(kKzK*< zv<%DJ#>Iaul0fUve9z2ArU%A&I2~WFpUk-~U38jtjy^^v>80~s&r)jId>f7aSvJ#C zP;4H&&+#6P|^eS(YXx+||WKR#LW2%7dknjlP<)?4NIOE6hD-7dz ze9x_}a%i|1e+Ssxim#q1;Zgm<n7LGK6z@r;H#!5YXouYMCs#Mo?(U61SwZr>bIW!dEE@>tEpd{|RjwZ24GQ0K;p$r_sj zn5bdH?ICX81kTA@#@_S9GPM?;9?yvtZ=OL1+_ET8w6r@#;`huUW5lT1b+J*Ku=vqy ze7WOW9qP$Z`HLZu#|IhrGE?4uj5Z?E&^aBhwd#G@RqI@&TC95w*E>XSKW3T_xI4Xx z`0;=SbE-Ce^gcMNLT`i;t_J{`Qt^0WSzFqQg++=x!r68NDJrq&w|;I`6HGp`vv)Kj_^8e>-ES|wq-W*E6&#(i-z$f0k5YE|)= z>fz9C->@2M?TkAIM2>Ee0lrF;=|*q#^usgg@B}5|BtC7uvwuN9%_3FxJJaF5dq(`E z{h1O1vY_X3WX7C)mOlpWqIDG12vT8fRv_kJ-Ra*nEu=7|ZOzBpTNyf{Q90&pjcVKU zgXj*em~m{UxR_!$^Lwt7emliX{5o+2+7S%}Ogr$+sh)=5D#(;^q_c0g#EEIxQ2mid62C2s z#$lcqz;vOa2*farUY(~FkAW4^*2XseX45woq$jEobrin8K0KnyW3aGM==0cl^?PQ+ z+1{MWaJrdH-Ore?y(ER-C8pPKpfY#*36QMN85wH&-Pu4*6C*=8#q zHC7?RJ31~s?tv4LnJ9YCiv!e>i*t4%v>)$%(GS$s;^hXVqE|Kj31jKg?EmcFfjH&% z!i{{+TMVX)?X<))e^9bCZ-y9Nz*p@GW`+@6@`fwJ%(`nB3p~5FJ%`n}FHvClxUW(9 z=1$_bNK~)54J);4GQb56d+m+g7_YNSGtfTyhcy+dSN%CP?cTu79jA8&S|!QiV%ug^ zr~L}U4Q6-;tUJ|~$H|b+hjD%0BqcBg7<^-3L>dylzMH4PWzeZBzLSSEYW#vLU{|6u z_e4+y=ly7|YS{FSkm zlG;;8PIo`PdY_&VCQw!Mg(yQ#E${CA=y${y0 zB|>zZ3pL<)y&`SoC59%QS_dOWC&lKgf;}pdMf$c~(>lK29s1gnAyg8_Y%6Yd5tQps z1d&T;%LD6g8PCmf4b}T_6RC{6nb+Orcwx(j+p+k!Mfx=B2ZJfN^VuGRKvZ1aze|sa z3tCt1(;g?op`)LDj@GO4e8ks3yf?s}z#Fj0x)(21An}bcmWpBg!eo={b5TK~VB@9O zXE&cbjEefOv7x;D9Q=v-&lVdSH(Oo$;Z1CiAtCF-(;3gH7S~1JCCA?G6g{WSnoS>C z+82w-G)*)Xo35qss#u#5F|&&L_Tg%0Z`9EcLC67GY)2I(uSTcxup>U#uO3v-W7b7j zKZ#49oTwkbW@Ob2g+zayL_IHm+nrY8r)w<5#i5@6 zZoWI|vEw3=z0WX&B=Egxd7kNjn3Ui@;+}F+4))4VU*T?Omi^p%v z#Exm0ZXG^UV{%e;KmA!wE#WD9d&rWjjevF_&^qNGc(!ST8b`bH&mqIkMebZP+5MOD zI(3kg>p^^R4Z^P9Lg`1z6H-eagX$ZME3&416W{28aWN^|mI*kw27%7+IEgleHYb_* zSi3*4XXaX)f4{g#Ku0ZHVbmsE;UjLbLp`wF&X&;7ZVr(U?v#16f)5r#YKkYd=@zWn z;b_)u!0X1UxPDxZLBZUl*!Fav1Ch!6(uq)4`IWZ8+oL$QWb|BYmqoT7QmdDR`!>_T zD3O%pM%J-#pEs#}L43VGj&Y-l9m|dBd7_5IUa3_z&lPI6Z>#4Hg%cV_5GNC(UOZg< zDQ+Q+fFbn?=*;hs6Hpb9I}@Q0$d4>7j!SXNoLX{i(V5TJ*l`PCtxAKl1ea%3ds&qB zVuG2N6$)dAC%<0XA6BY<_UQW;5&Q5$4fW{jw&9P*;2H`rh=8>fwS6$nbVDH%$Yb=} zWy|wHucKsI6mM40yeX-+Jg<`7_I(*soJ=FB^Pk&5iGs@(x=F)@e zC~oVEN>=@LeqQcBTapKHWE1}wU@HdPLOEHx_t=f!jtOdL|A&~E z{0$LZRc|uASz;BA2w@M%gTt%Gp)=saMns0^rK@taCdqXpaoA=ZX2vi6?QDO0Zbzqg zO4kVOpGFpJ5em^P(%6ja*kKqSnl63a;)8nGxJ>$^eY$Sq$N1;N0#5NIM?ixUPFt18gYPd}B; z>dRG{pUDgc;T_U^?MB#5u2{YEYIJa@&OMWFA7Wy|3T%G%viii{x@T;OKOP1VKXz5s z>JMSem2aFl%GPV@N#xBEcRYQD?A>uMWH#EsQmavRCp(Ss-cRxJ{=pf1#HJ2Q&gwe) z0$XS^X_77aSe~W2>S3LIjsC86r6mXK;j5&(j&P?$2DQ(y_$~<9+bPiBjyWw&Z8!MwQ*ghN+Dsw?D@~6M$=mAF;`jc?loya z&oQ>PAo10qowbV2eQXP1q=C2$t)ve_ARqZ~HQ|Mq?9)A}{9&G?187WWpy9;ZKnqEz zn%0=0CF^6B8=3MDoH*pk+c#YSR>O4uHs_LaR$B6Gxew?V28O!a4+6kZ|NRfT!s zb|lQUJeV0h$|T{_RJ1zNv7vcPzcuI2{uXj;^{Fyt`o_+URC*%FT-uPdAPzNL)KxVd zhg?fkMbh^nY@{fAwq*owF<2Wm^Wa`})R(IyvmMb33e3H2&PK^K0xs{88{Fh>qnD{l zn8y>h=k{8aWHv=OVMD7Zh2En?G0RW$?NNrb5)v&bY3x_~XSWIY$h$=l*$wnxLo{beeW4K!q4bTE9#qVaZ8S@_XTPaFv|T{(4vLQux!}(X@G&O)wmynR zO0Kq~IhnafBmI2WkJC65q36b#km(ww{%kp3qTw3~2!-5>?=nkViOg?g(DC%%GP1{& z%h?B?Y3bu^4;vcV(O+lMhByew?FA$Xi!{w}h<1Pvhb??n=>3NTnmayTmP-zDAGW-` zrwmO`s#TTXh82YGpQ^lP<|Mb)xe;BWbMjO-=Zz7qdS8#KL}wgFDY>n^m_dW(o2Oz< zQ~Z>IE@d}1_!K2%;~q&P;=;#eBb6J@YLFgGb~jNE{QFp=KiAHUHqXIamUYE5@t3Z* zgiy(_5coP|#z6g?b+T{AK6--qdC&Nx1ciY7bDhQ)f!#)8Dim>uX0#%rH_?YBfz-l5 zdsyPDXIE6}<00othon~P{5SzQ1gLO@V{$c&;)7V#M*A_@>X0sUb^_P&;{3|;P>EJM z167Dr)>j6vvkhlQbhi!}?u?6|b2q>E9oy!;tNoP*>dq7=JBI}8yF@E5yfaLCsY+L~ z31(l)Nupow?WsC0e_ajfb#8-dF;~RZSaN>MszOUGpnZnxm+|?D^Ga5gx%=}2Z+&()L5+pm)sOpKTan|}c zr;xwHJoE*BqclvAYmk3QDXEl-;ughwW|bTUs#fWDkvCEqfA1ee-NA-jf#ZTknO_&p z1jESMXnOE`PLsfjp8>a zS1m?ou&%rf{3=>Gu)Fk+VoHuh-^^2;3J}&U358Tc`L#fzXu=n@zS|6ERrIij_JvwC zF^YQSYF!mxdu%n%2N7RA#4BHw53UZ0eZwN&o*@aa&XzA2FCV%hEm$~Gw$x=|6}=8} zh-q-bgzZ{lxoup3_Hg*b)maLXwY)geC`GYyWUP@|%{_N$@U9@|X`3rg-X=Y_%7$JBZ?se~XgrX3$NXvIC4BGVqk%y+2C@1nz> zn~M4F2)TpmGcB5?p2wT&!AGOzFC#ycs42EI4aoEq?Bf;9{TL2~wMLxfghC9)4~Vpv zbi|*T?rz?e@ZPD^wI<_Ke|{YCvoPUoVaaIF|Hg(QzV2ZBWrJ%4T?=jat2h%1FORmB z{aqJIZG^!7JRoVXIG$)U@YJA7x@%ESOF-}vpPY@XK;zk}$aol`V{T{Wuvwu?Q4E;N z&!&r$1IZ&0ch2@1`iH(Gv&wcV?h)nn1|}FPwZoh%YX+UbXGB0Y6t$RV#XHf4UsxkP zy*_+F(WR>&$QN*Rbx8rK2H(5bWO36%P#^01oa_sct=sC?(YF!MXmttA+qIRidkDsm zJuc(EFK!SFP8pd&L({L-&zf|v9HDI zI>WgCJ|2@+V{-ZvL^p~r3B6j-QXoPddEXRCbU(L@0!4S6u#oAUNx%V1sXtAr-&x@_ z=7hb3IP*6`4y_xWv+HWKWiT!gtR`m|<@&6YNZ6vaLukYh5$WrmhaQW^oJw<5OIHIY|yR?c4jytW*bUeGf9ZgumUXGMx27Xub&! z(0}1KO=X)PMo6O>4Skk0++W<+RnP0zy!P#e)H0riWGt0EH2qp*-g@NpH2JB=y-XP`he0w0 z!Edqo)&gX$Pn|ci{OSGaVJO2DKPou~k$?O%J&+0a-rOvimb$H(pW%EX{r*QRX~c^H zXzz!ieg-Nc@3G2=HL*juZb`HAZf=Fe^Lh@$tgP-dgz&c*eA`sYezUvXXHm9ZS?iD}F$x!?#sA+c()Hr6RyjvfcbEK1U z95ego!+rZFk_|+czi5HXuL9zp(NNS1vqudU4_}^Pevo3Sh$=5s7YfU{=_9t(YP%H}I5J<>FN+Lz2bu2Njd@72~6{rcJq!HRf0Jxbd9YiO})G zz}macRNplS&)cvFZ4;`QUUb!Rd+<;deYF>in&Ht!IO+^6spTqgFYIc4uZW@>RH70y zpm?u+_k3%e&_m=hkIv8j(h5bY-C4)W$${vHo%8oIHgUY|CQIP-7H!SOb4i)gKWE3L z3O4km$iGmrYzr?0L$G0z9&C1?4=@s;E_NKfUa|<5FKO0gY^wf5(vInN=)mTP@Q9 zufmVU-xw<^6L*-Wa$fjm{yC9Xt82+dQY-yVJl6999l73@Z{9YYj8DRBQM&5U>wTGR zIcFYjn=c947;%hRlfM8mPlFrWN}?9AX3YV6sV_ifR9m%@Gsn{tq`32t1YRQjd31Mx z(1HzCGo1Csn{%|a=>-w$oxJbqe29LHi;^0Xa?uw7e}_6YZ2oEv!v-INN%|Q<>wr}u zJGt&#_x8N_PVDc<$UgAqaG3UTAcsS7RvYZ8W-lZ!m9LL*YB6mmkq~2J%h66Rq?njtl#7^XuipvVJExSGCv6kr_tomv8GCZZV9P$96G0 z_+vBlO3yZu)q0PPaE>ijEcIT^FuMCsD5P0?bt4wF!$8&ZDCxeFKjg&Grak)~Gx-m( z2&Cvw?$n?@4(~ZKPo#sE^Q_Fp7KsIFGa~tbPwViUK@U25O?CMtd9yR}C)2~Wd z!YBB)NjTN|mn`4{(C?2*rmI;pS;IeQtj4WwlCa3K@v$NY@bKTg03^19HOYj#&P7}X zaP8f2KRgQi?MLhX3@vPxzi~W2>DrO6)OrN`3_iD6mGuX_q_S*K<@w3S+t^s_?nj&Bm8QGXwf7u@-qcbcb&G^%>6v6dN=Y*%Cj>e5mAR#e zIj7Ph!e8h!`RReQhQ-hWL{6`CIJt&F#S0ztfIdD$kRoddO=G3q5ZxLsT;Qvw#1j57 z%c67Ej>=mTbf^x8poG`fGmpyU^bG?XqC@*pQGJEt+y2n#S}U+B!OFBcZeZa!A}5^gO>OL+utoG-uuVvWDA7Q0!RJm^G?uOL{R_DGnY%7~8K z1=24!dFZUX-MVTW9frjHx&!Qg9bL)yU;&Z;UKHbjJkmQism%b7@au2(faK`fK z(`n_5%>D%NgM-PYp&^p(@q>HXgTV8p5psSZ=1Np)J5|jdLc&V&_7(lyx3guc?fAAL}hivXrp1VTQ;MvD%j(P6l)ArMM*5& zIviFKhf7zA7HrQHj|- zW)ix~d{{S8qk;_bbU&p9le5vi4ZefL@9EueKm)0@=5oa{bibTz55XIRrSni0aAUxl zXSOfaAzD7j`e9R(erj)xBUZ(F$P1Icc`L2zpUa_;;XIi*2I1d5!xjHW&o@zCBZ!YA zxKWp+1re8fi9WS2QbxRE{*39b?7E4I+~**vUk23?FsPJ1Y)6CKoIpAP)*f*1Jp6h)acc&Ev7zVD3fuz_h$2|MUr#1gjKaRkSD2;YLYCk$$SP zj1sy;5vmf6ylUyiyoC|Ysi-NgfBf&R2L4A^RaMfY?rys|o0s@7@*yc>B;qrC!2ZC5 zpeD+10Q4gg9!jVi0>;ld#L`6q0rCLb@i@DE(JJFzBuGI?H`*#$9Tun4Fy-o(fSD=o zDDe;BZ!ADas^s`nTc=Yg*ND9%d*4e*k-9M|a~&0_Eco$Uxi9(B!u4pWFU61d8lDsu zF+GvkQ|jesj;2oxXDS4-g>B13qNIM1;2^dm!^S^D3%f;dwF57;(H9P8ZFZ)|LAcwz z*6MD84i7KCUN(peBO4RvQakqHA0T{tACC!S`Gg#7l(R@49BxBN-YQ}>5Bt_shj$OR zR?7KdOZm+ajX!(>1QW(~Pr?ot=meK)IM?^HVpsQHm1FHl~5o}Fi4BsQVHWRsAZCgq>Tm0c8~f;t{W5VAYm7XN0iju zJOG$|W~asOnEPh(@l#da%OHdH&hct#7sye}1hPllwk;`QN8XW)1DUtwLEdX7%_8ld z`kj8hBq6&v`B;vQYKOTjleP%@CSKKy0022>-2tQld63QCz#J0Ceo{}trTr|C---eM zHve0D5zP|){2SDPjtiX~ASD_jHLKB<#ATK6$5XxR>RS|85@aaAJYzZhn!@CpfBpk` zO&al|*fZ|I*6;NcSY-F9ys`ZXc~^o2C(SSS*zS=1sEssLcS4jA_VC*>$l07M={q$S z*yMeJI6d9%jS;dljGqk>9tFiZBXRt+!$9sCBge(=W!Mk0zDED6T5fH` zukY{f{v2edr5J7^+h63?&mNX|x9}&BQP}g)kF~m~9eA@UiGz zxrEc8WAn<7hT6u%Uq#dJ*UBet7WP{s$*E?6o?Ft#RQp6bk!}>(pB?o5)^WCR)u34W zm+HOdeMgp=%QdIxDqkIzDWwlcxt9vLUCWW<0b-iz{6?j(e|~Jbysv{tmhD#%+5gYv{M|%Z;7eqB1o^k6fINNJnGOHX1R1X3U zmtGwT4LJhwx?4r!=-)udKtuM)#ao)4i5wiIaG3LDk%=m&RFj_PRG_@#B3PmHvZli; z+hj#e5JF<@(JCjqwkXP*K$rsX?=3D^Fi&pCfBgu$_3?tjh{wJYep(-IRo^j?4Lq!R zc`eUkGL$C^fQi5uK5R_4J&~MIZi=E}za(nYc^eo{b&8;A4SXnID#9m30;t9oEEVeK zm4@je$N${c~K)!zf2FcfnA0T1sLJc!K`N3jX2DVRSv zi%U%7oUmE`B>Ld(*`XH;5PkVI!UB4e(Uj9k;eNQY{QVl( zhcEKUq;`-_qD;h<2n&OCZy>&@r1cS4@fr z8<7@F3&oa8q6Bx%6tlaTOzB(?TyIeZQ8KO5v(g(zRuIBC5RPhUdB34} zCpZhAn?FB!wu(EXQRSJ@%oKHqidlU9XkZP0J5f1@@6=hEOhd*bRBYE~#*DHCRY(T$ z>gMVRtOjE&TZuIg45KR>ls#81Xj7o)Tj5o5CA8y2rsNOTl)G=xN+=R*q5myRVcM(-)fiRe-40m`(C%o!pZm<^yv z5Vb@_#yG~@_-x)LG_D0}-4Tl!iXQ1t;jCb3FxBK{Uc|<9^~0d2uFbjxMl2K|%yk}$ zxEET9_oNBK2wJ<7x$HoX4d_QDg>W&=hqOH7Ih+CUR;`r$@&V`i2D{c627A&3v&14g zsP*26I%>r+_DKX*nlr`3*2vrJ$CygOD!cJ&4XmRAP+p_`3pQay$t?0E8A#@a2cVK` z@Qe+`c33;|XbCjrf76gQT+*)3dq^WWn~#^eo8OGM z*T;mb{uV_GF7x;WwGAJGd@sE8WugvB#^!Fshx~G`hyE&lqMoXgUFs~UJaxQQkPuRK zpG?;u+^sgw9si2R7cVK9M12-(=MYqDIy*gjRx#{Redh0fZqp#ZKF+k)sGL|u`>Dhm zGHcBcA_jp;&RwgV5O>Gjl@SBmHm{5qRuiFyV?A9wE&|lWDaA9 zP(+{P2V>q#fKHqY5lJL7baG@jNJSlj39CSp}~3n z`2==1q?Ssc47m@w)S4{ZkC_G4xC)TOG#zj7(CB5;(dOuuMV zpTCfcihM=`vV|{ltcA$tH>z3#sQp5_qr7>G5#>@EikK3(C$+P?npH}&Wv-k7#UPO%-Yxl)J`NNd97R!9j3Y<{=I+BpajLuy z7_5|Xu`;83nu`KMa>C?FP429x2!!xl#G=_XW0XB-+6O)`@bV&XoZZt77^0YOj-Ncf zVJRP}OYL}9xu2toBTX_6j$PCjJ$xj7P|Gk9zu7jRUuJxGfsZ#%BbI((6Ok&gA`=Ws zx9h)kpuAIE%xQp zcg_+PWu}EcWeK{8GKFS@gdFfoy-NsrCnL4ODS^b%va@|bj3DV-e$yH{@9$bya(H=x zB+xQl<$@W;1~@BZ7%n2@?m5}J>Lrax7ffJih#!;Gez{@=uzi5& zU%%K#`m=y^koRMbZK`5Rk@Pk>*f#6&m?B;^8)Kr*sW`{tjDWSr$@0!W-bJy&P`%RS zLJ{AXd*Wo1PwXh79LYpUmzm+=uNy!C`f?*M(x`-6#R)*xr1$`rP4A`G>x3Ejx#%Jj z^A;VgTBp?zVqlI&u?;=5a&j>V=a~PN(be*-rT7tLA^7wlsvtCpjbMVe(5T3;N{FNs z8q*H7A&Yr{b!+z@m1=*)epKMKBx0$d)Z))J!#}@3(GOz&c?EvOgRqb{-Hbz2$skJ| z&Q7dIAv5e0q9fXS8rqMBpId;*4GU%wa@&DQ#mcU@kYF zi;We8bVqr?D7I}@#+5&}2_?MP;Us>#T0doD+aPZLQ{UPkm7^XcTYE8*ue$Wpx?x}GX$i-Jzv~(Tru$QHYldok{Y&nYwz-Sc5h#H zo-89>**2U(VTcQ%YV{me@0eM0`rTAQ5gYPR7a4(P!aEP+E<%X@?>^*k3M7lEUNGBloR zl|vD3voRYnJ2Ya>6FQ6begIAh?SM&$lxrEKvQYErh>+(WQpAp`D$9!SMOrz)Tx6^( zX`$(0eWRcrUXlu&Ns2Qe3)JctimS0c5mu3Ym!vLFM-usHQ)jhhq_$MzdQ&L9W27Qw zXt29&nO?&boyYw**lpJC;Ddu+FZU9 zLEo(QXYeZ!{$wwTbHV(!{7N$lF{|J#ZIL63e&2MmJ0>@PBROB%vuK*J50-{Sz7QJ{ zj&(>yWc-0K>)4xw4ti+Olee7VVQB6X()r@29@>Wlg$497#B1h zMkq4Cnkq_`Ukd7NkvxBF~NVr5S=enh9S>DCxFmW_)HA z6`AfSK{75QdMy-3^igKEf=HYgF@o(6Q6>=uidsX#8J5%8ve=gSQqFi5tresq`n8}* zm3GMpav$?3a?&Q1W|D6VD0@7Fb0WOq&W1;E)Ef_CQW3AMo)Gg(#T3aXQ9{qpCnFY- zg{bq)+Ed5QUuC7DlBs*2U?QU_(PNElYxxMp7N)GBeU{>j(~rpbrM~9yqp~=^hLZeZ zUVC+cuqXN+j!wX;Ds*4eK`U^^7ww*(fIEV?Li725QJA$*t+q^D`TS*@yr^@bN5Vgg$+J>|7QDwI4L!bioR=II`nXAvN zJQFH2kQB-kZ7iCd%5!GrXM2w`LC^+K9PCb?9<^)M4GE>@xH#)_q5V#b3j$hB7_PLx zr~-=^k!3vjx2?@`k#!t zGTMHV#9dk`#%NJ+Si;*J&lVjrl!0NLTg_0Ev&boLq9 zGz79uTG!u_PwC~5%@yY^Dlpxb+veXXpuEzmK;(6443XO%^_yBCf8$3NFFOSyi5kw6 zL5Lfr$oh@AZS{@n2K@X5INJ1CJX%TRdMhib)}~d^^aVw<{RkVWRa4}ih>`?!J|>Ev zdn^2aT%fO&3^8IAN<9G=i4C!e6vxGLxrs=Jzk*4zaUKto)a|K`xB6AbkXvWIf510n zp-pbaS%?yP44GgSQN@|jd&zNeAjmik^Urzob(E_|HHfHRhC5G@_h+<)$rgY3Y)|om znAsP}VMVq{y{-9@oQ7gudA#*q(8i^^AO!PvdHg+@hb<`jE-seJ;S|N(CCnM*(x3lG z)JV&GO2BfCebN#3m~Iu$7=yJ?JBso?KVj#Q6H+vrw0v=^PnmBjS$w2dU21<$?Rxm= zqe_Sf^AwjAFu7MyrLK`{GJFs>(8vYn1|`YLD;&qb#8!Meq3=DTp9B|~Rzw>4BBnpPe-7Rd?IdC4@XB@#C4~h=2f@>!*_2UI zPUxGNs~N1dXiO~~ovB`u2u9LJu&sxVw{7vL(|fM7a$o6C-s}^nzm0-&6UA~J*2iSp ze+rOD$1X-?LECW~TkiZ7?@|;G1qWLutzt=0K1QOP?Y5NU1bTxvq?v(wR97dl;K=Fs zFaZCuSOh;>TsNQA`iL-z+Ap$$&s(?HmQRBKbCG%a+NXxfx{gl8Ar>rITkRNpqb3wi zOn$fb{U$~62vXhZG`>$GhWAl!+d4%X^npG?31C`_Qi^r8*1e#kg9Iw#8{-hYOyWmH zW!ugI7+Ps0jp0*f$AX?Fr>1pD#8vctW#zA--3u3<9i!86Z3L=x`xt#2HY+ERA`u6| z`J&SEgVg0zUC3u{(>~50iGIh&iLV4xb!yY4Qz6d~)f?`x%0{m9y?}8rQZFcPs(G}4 zeMdG+QCvoYtWa5rsOITEh+-y#e6G3t;!*iflrLG&5|y;kw(HNmahB-f-bg#~1q#UT ze{%u&rqa5jR2Xx9=-IYeQE!`}U6EGFF%VXH)=nxi#em@s9Fz%C`Up|Akv#%#syT>x z3TKv;3(O0w84TZC96tDLNK9xbK0gLfNQ+B40S6U0h6QejAln+58?JQosjFmAxA6%7 z71Rp^jtkiER;XlnJXTb6v2PF>Caw6vQhB$F)F+wJ+qq7Tl?iDj7J5qJOA5>QI-5$Wz0qa%WyM4}DF(D0JkSHxHiHyrcBTd$bn)Nn1mKTQziO8W15 z?#db+42VUfah^J|K);t+{kd<-$t&)Ud7akOoq!f)Uk@j5gqXhg_Nkt%>{vz>4r0ov zt=iqWPY0nb@S#co4I6lVmT?|Oj8J@(3Kj}C*q(QvO)`)*ePbr{(1ivPk^@~~#5i;y zI?lbP{(fzxcX(3Ou{qqjk^GA}JYZNdVw>2Nd@Af_=T_6^&l6U^eG_RJ2IvZ>Kl{=9!uMy9^| z17dR;uHfPwET9(FXWf=_y1Z20&a~<2xlG=Zt4qihF5{ly5Ui1X(adzZl-vf6CSrwT z^&y5=J{`o}CktpZfS!x~lQ8r%x7}tK-S9CiZmOK)?>E$m28G~vXsyM^AiZUKE=5eMKXWaY<(?&i5PbAMnWUI3 z&hWYOA`xcVI46*%$=2*3DNzqWWj(o*YDX#s#S;eaouuTwdMKty3nb@uEho$EuW&0v z@jp{!Ac0YOeO4ZF4dh&KMx5k`c355ay|S$jpDqUlWWi9W1V`?+gASwrdH|t_&v=N> z&@tjO9$`;Ixo~nL>i(?iTK9jg>gg36*80KjO6ECvAHBf> z-B-q!z2+&v)Km{<2-9Lk@KaN2#@Lq)KCgy~7H$SLv#mGsom)7s)N~b^tP^q~Ew&`? z8q71*9oyHqrir-1yYFhyUy$#r)-AnAggfo z49NQ?f46wh=irA?1xa;cZDi6P%Uac7YzF$__&YAZ1NfXL z<@i}Sh4<>EaAL~AGqovy&;v{1X|8xO3u+9A@Tij7fn-%O5Lvb;)XPKx3391~JD}^| zAW$He4L)rY|FD^fH~dQ8(|~Cs?&I_MxQRK~9dox^ z{QNux^eMzA_NGp(r~t3$)2RDMMZqq~4*yt4)7Ihu*TGDjFzT_VJGy^m{N1fzLnUmb z25Zx8(FSPY>(4^wSpq!(_>&eSS<-FGc#!nXI)6r=VT}ZZ#w8)mI{>WKY$#=?WDF(tz;R;Hm!*|Z(KdN zD=_6DkvNb}_WnK#1Z7i!#O)QJgbZ`z!Pcl5AMUM}$xjb+>f|ZD*N@sHSz||FMVub7 zNtfPR8#gZbiQJM+4F<3qqZ`dtj-TzeP1?z*n;_knDS`>xIs!RlZ^*sleNUFl{#8St zK#xllP~ZK&6-F~|H!Vf@vB`>xe&o1vcW-6nMVz;Q`m}Wdff4^sFnk*w4$#V$%@&?w zTj~J|+y%JMr^oM{g0Ro$7unW(q#0)!A3Y1%k7-KmQ;l)X(3(2SFq-lOw(~6k<}!s* zzy-7pG)@DevXS$HfJxEBRCAuE}&WJEzD zXA{s$BD8R~ZNq;v^7_F&8gRA)@J?Gc6vMbs`K_@xxA>gK>&o>jrimZwOoG02KPx{e zeu?CD^C!J#>-_a~@?3A`Ex|6u{VBqhIc6w5?>0KAo)#o3Vn`0!)HxVAGj>kFt~wMK z$OS2^qB!g9^)L-3m~;dX$XnV5AZ|KfHhcgH`fq|ntFTM@(-~vDt!Dy2t5TzNxbeGNcv48BprZXG7#z7w{?~!^qjmQn^V4{ zJ{g*|2fDDmB)5Fr(o0L6MTVH5N#a9)Z6PBX5`+)m&fBCu0YjwAz+6LA4h*7Twz zl03JoPz#*j;f1DNAY?Ge+Vq-lpeC;Zoq@Xnlfhs|s-3IKu1aM`wsIX5Ji7eQ=$j{J zwj3v|0qD6zCAj*28IYcM;x5pBFUAGC($+`-NqfuOjI6-Kw~yShibyDZNXPfKJ1WRM zP3asLW?AT$pv4+?z5yuQZhqS9&CCy`4GD!jI*;=o*wwsIIX?IlROKYULqYgyMHoh` z%!!Sz!r-!?<1mFmkU*m-O{XgJ&<%VRy>>Fllt^^tk{O^o+Jd0)uj zDYXTeeKMtLT1f}dA>;T{k{Eb z2}`?;Uss{bmyS?|&mk_;2s2J$4?Ao{!Qy9~d(F|-O4Q0R-ho*PlPKq$UQ+uR_?ZWY zOEYECwEiA!vp;SFx)(%^oAEeX3tm-kyw%Y-M0AiSP5km;wM!CSI9#8;#3o|~&P6L3&FE0H-$wT*} zrvzO}mUIsOr&o|D5W289_RW`7e+WnpsfGij8T6=RI99=-j-D7G&mt2qy9yv%08J7#@Q& z5O`c3(usn^Iieeo&w4vEbr6(Tq|~?xAW_Gepugej*TFGT|GBgIU+5eHih8ODBlCcm z&M`m9gQ}() zM^W%6NEN6~woG|!1NN?3u!g1Um%@Zcki|W;ty)C*)a={O;Y1fl1$sh)sCD(K++Vud zZ$g-yUOJ<5RCn5DSZw0TUw5It7iY(e4Gq#+GxrdRlCQl*FSg1Nk=u@mkKj6cX*-9^ z>4*4-?TGm2NEVD+`_5YK6R;cc2>LW@j%~DJGi|O=h8H#liVj0{%xQUdeZUN-kw(l7~e&RvlTei zwh3d7&{F8HG?$+<5v)*!06-dWb;L}9Zl;Ke;0u3pc|tZrs_$VbYkw*K+d|1 zojsg^LheI6h2^Pmn_PcJ$m|01y<{nM{>P9wD=j(?^e@~msx)VGyY=>?6TW0EYa5 z#TW^+NrJ;K%CwqO#KGZ6j}b!4uaw1PNT^k)x^B%`HO2o}YKSOh+N~%JvI`hUOc0@T zUj!-Lok>m02p-~b+~yHbc+eqoL?=(n#+WTq%VZ&}Q{*8hj5bi1&yG3TMHMYpxv$I? ztpV?Zj9ES4Wt))Z=j%j{orY?N_7J7ZymWeSR0fbjdr}+z3eF(-6!wuy$WN9|q>H|? zV+NppShPZKP(v6aG4Fz(=6nD-eeuf?=MAH+qPAwowTp=Gm)RRe`=G2-WlHjh%37Rq zm1go@FM01g`84T{{p5r2*pi0Vn0ITA7NK4Sed;waK{n+}b^vx7@5+cuz-q)XE{U=N zWHZiC+=5Bi;&zeo98DH=>DJw{>J8UmuBj_$ziY)rl(?saAjK@@n9!@{#gbgK;M(yA zZP3OsK!xI5w0k|gC4`hih=}{GM=R{&=85F79>U>SjBewt#;(}3mVEj&5p;WN7%WbA zWlMEjtlv_MY_{9r+YJ>;UokzjD`C5A45=pWfbTxBpE%U#Swhy$^6Sy}Bh}(Lb4RY> zFbU8bH016pbF+Q|xflkRv*^E3tu_M+YvRFZ4TAo>urZDZD!oNcEaz0VLd=C|Dt5Bz z^($)SqwsxH+w10AwNd`8}H8-$%OI~=vtlex0SX`w+i1w&dMA|9(8@sszE zma%uS&4VSH0Q9XJvg>(L6COm(Nosgr#)UoXHk$F`{`=KThs9USmL9S(m=DBzMwNG4 z?Y@9)#r+r!+3^oZ)@K|n4lJW_E{_98IvT4^Il+Nz21FZPJ`{T92z+w#|o^k2uuD&1A)w)5WRI%8+>A3Yibz7@v9w z3Aq_{og>_Q+E>;(3-;{v*r_H6x0l`Rtb(+0?9ft^auz3W7_%P+qyxq%3)JnQi#@5; z^O)@stW4*K-o8z+`~?9aY)D7wlU->Wm8$TfjPm6KA`!;VkD$EeyFJlD!n)4x%f2r*X=rx8e8mZq7lQz@!aEkDh$^zZ33muAuDk2Ti z@i}+dD!P~}h*T>TvQ3Ir@>McbN}(f>Vx?&*GX9G(a{JO|V$-=4-yQS<5KM>sjM4TG zp<9TBc*54zyl|xb4nO)$%Uub9IC^*=qSMoj6{do&wx%QS6Blss@Of!0%c(kM>Tf+j zL2mSzjZwJZ%9ER%Y^6jpxB6sK`}AzXB{@4q1)gEv5Z^eriLrq|dD?Y7swolkB83sY z4HqIv!PzOa+|AWrOvs0(u@-f64s@t)>ce9kzsQ3yfh3%ZY>a8U+eD2d>9?;1ubUqT z9c{5WrzypMXP%@;=XIDjTC1jNT3v?)3nL!=GN6OH*j)vEh~Fy z)0q`3=KEHP9e?1iT<$4`g!dK#K3$POkWAXj4D)KcOKd$1_`K|CbhV+ZX*G|w!WD2P3^Pn-mpN-PAG$Hi(;5t`(vk^UJw zO2U*`x|_?ip8Q|n8DV6I=(TVoC?THzzW)>eO8C$vNChC3zD1C=TziMuX~t{W3Efuu zAwvq7S}Nd(moNVp(zL1!*xrKI@5CN(L!4;=0@&wo_ofy6E#Ybe44VE2|6~6z{7?Qm zzLencQ)@*X+*3TVR%u~;3C|qowle+$XI1~7FNvwZjP@@y=QX?a4VK^Lwo*mj=-I#d z&hBBzLfZ!)JY+frh15J?dN$Ii`d{cA6+J)%p#_B**J!c4p%*^+pPrPK*0V2p0_UTNQ;_BGBQC-{m4H;8EYx;$ zQ2TvA4i}Fr?-gh$-lCs8Qc~IlaZyOG*mk0{;YsUjmhO z@aryEJ%b7)4HDgOUUK z5flJ1c_GrI*Ihj5ke1Zi;AIeF*WRjhQoRfj70{BK6JRCjp0k~PvXYwD@R;?fuGJJS z^yxc*1s}QdW{QJBe8gMA$XmOXuuy`ID67|YX#=?gW0%KZq>Xw=5YC7JD&8BD>bKdK z9~f=~6XS;smx6xEy_||Jw{|M$P`Ho0@j{bQ3ML>7%0syE^#C6jaC17qkJGb_KkxD- zd_pi5_-yFh_NpKhu&LA8<$b!rwk`ndH#}Tf}`r0l-HIE2bJ???5(+{jG?75lrniU zzrS0i1Kd`9u{ReGx4Qw68>|6QVdU^TuU8T-!`GR7ckys59_4-mbmVmr@d0`AzWM<0 zvSW=7lmk?^SG`XqbwLdFVu&5!W)GH5fMU-w${RqBnE6#eS+BUabn^{hopt+h2r)zn zzXO;SD83I1$b=ndD5nV++XC44CV)Yc%&Z&R@7F%qb8A^=0h6YUM8jVKw}uC@I!eqN zYWjWu{23@K0J~0+`htQ`(|KsPco^6ofl&~lZlTw-Fb`euMHE!*a@_#HGtUnRciCH? z6qW^$_^qjms|#Q{DLv^NC!5|2nAYea>sNm?NOMzPFp8zW$bI#x#MbsFrRdDse*wF5 z0Qy##Z7VR@>D4W*dhqFGFIpSd3|HQX<{2wBZ?NcL9xu?(QxpVb;@vVmap^zC!vn1| zK3FbmP6H(WJ7iDzjnYsWvD&Zna~dC`UhsgW2~EdupCV)^0q<1ouob`z^zg)L1^*rG z!`)SWP}6ap7AuAC8BeBGp*=QFwri5uqvz zRkemczU%=cLY(Df=HykvV9y7ge$-d%W8KEh+y1egN_h_|qLD0$JuR$^zDhFne z4rg2?5~SYTP#GwGs@-xRy|dI{kbX;KU1C^7jR`=YEUxOlP+AgR!mCA z`_(a;f+hbTYU4Y2#zRADe1c@!6Tqu-#t#5S_5i00+c-5_6CR6dY$!leoL3yCFBa0^i z%>^{}8IDr1(Mu?m?^#2I`}VnVbVq7drVD6?V1qW4YhJRxh_+5B4tB_7oz^u>B0@MV~ z61%0G(~vrQO)`(#f?FxDOE@OgYeX`M)PNd@%X9Jh8@UVXz_)7?*n9*}ex+q(E0ysf z^wPOcrS5~<y4tku>03HlM2FQR~f!>sBe;U*usA8oFy$IMyE z0x|UaUS?@TetAw7cmOKrNK_VE!ng3Uo3`6XJxzur7)WqFFI|_R`;tizXPa&MaE}@9 zrg8BE#Ydcl7%v9tH)_Y-8KN&H(E*qJwj#8=u?Hz3gwRs-ecG3McpE-%+NlS%-~(M}`fM{5 z3@7JQo+|iN0Cdd+Pg72BSDDq55>M`v$^mj08u;XAIuA~g-r=)~dPH22YluON=ZG5m zNtF095~p%s!b@kvzd#(~S@44nBv>N}`lH(K{{~q>d2?6VfE2<7W#;fW zzYhA57XZf_p$_FnOu<^^ibQa6^Lrz&y65){+9^^(Wbr;pJejiAd1gpYuUO;)*VIry zh;?*AzL$zE2iV?Tt+IE4&1TaVP zj2@YZY0z#^LCmJYKAL#Md8tE%id*Uu_Wn93(+KA6z{&UZ>IVR;(^7$vX~1Rby}zS0 zwwJBJaxDMC%xfcEpcAxg9-2hOqu@~I5w&iIKOwDqxHj?L6kkN!0CH{68O zm_Zy9M-#FMo2Ld4mLct)Yrz}RktIG2 zLJl#{8TIHyRT%ktp@PGp#{!Fhavv34*MU}*g;#W`!$d;7&w6III3uk`B0+yj`%iI@ zgT8wa6+c#oC4a3VA;i51Qw)NkN08e5UnY2a`_wb2#JO; zl)q5I(sxdG)p3pttb01%UmhwUa@(Q6#}5S1y5V2o8qBE`nM4Lw>^FSp*Ee0^;x?8s ziU~YV;wU~7N_szmpo6&Vxm|mZUdZZ1B7d~;_9Y11HG@0r#Od?1P<)YzJB(skHpb&$ zANu5y!C>4Q4+=QAyM;)S1|0lM7XvFj0tAR8$rl|8^1J697+P17l-0pCAODI%V8DF%GR6wKkN$FZI<)`+Kl`nF8*@z*AErx)X?Woibr}L zj@NnS_+@_rt^#%ee5s})8Su-d4_`U|uOFNb8JBrOdWW9!USfuY=Giz{nH) z>UxT=p#SS)mlX0Hjt>TfLBsacYmI5dT)f~tjG_3j5QN~cV7z}rr-S90{3Zx~{f9+_ zClX0QLA;m>oWRmMmp^7RMG=ax#{~i3MN(RBk9ki;HQ)-L;luCJ8Sv|oX4{Pfu3+o$hw&5+AR3m$ zt&@T-vI1%NB;^bEPxy}f--Zt&^OvA7b+pEH`9H(gP16H1==}RxAc*;MgFGm81Lg1k zJ$H*BJ{NinYK;HK%w5Yh1Mw>tN5DpCHCg8%n5hU9e)=fG08tM8S0{r%Q%}3W8Vc&s z7b!zQVdU=L-yS3dIU9UdX=7afVd_EME00Cv{Xw8(_4m{xL7_Sjb@0*26f5}Ww}TY2 z4v?NIaQKvKp&L#E;0$yeLBHoX2&$D`@B#n-VeR-^5B|CVs#5;5RY3Um+-0KCc<`l; zO-Wd}1hAdw_z=Q89F+aY58;dW|GnKBHLY+VkT9KlXQwCc|F-!mr!PGqM^0qThIe`1 z;Lf#e{sV+E?9PSYK>znXx8M@FzhUKfxbudBRifM3NyHjhvNMe~AT|`6HyDDD!r`Ca z4r;PNYrsGjxqGj_4_dY3%^U@au8ED``&a+t=rtdmyIAW!nMT8<`Jd5)>eUo{BmrvW zO8(L8<5|Qt!F7}ulzjeY?Ex5oXIe4*)7s;JZ}^D^58TFO{#S{0B|;{JK< z0h8(jxNO^ji`3t1uUz9YVhd7)(DH=;+wL;}@RBX-r3nClnZGv;pnIv=f65gEchg$K z7Bnk353c|5Ix1il|33?8(@7!gxAhWSa(?{1fZdtmuX;h1JhDdg|8{v;OM}f*v?kBu>U;km%YXMpZJI(pMX1Shz zZ4h+1UvHY7WTDI-@3@*{a+jKoL4wDgWACbg@WCeEfkdFmBsXLGn`dV?ZOLbdarwaa z0h5DqgyzkwCs#2%BC%{Swx-1rgyz-EJ+8W%Yx<14p9%Az@9oVhBaQ{jsZ<)&=fH8; zMWfdt^nFkGRo6E$>oog!h7lJl{on7=Zd7KBbe$BdX0X@7@$c{==d=W;I^OSx;=pkr zi2W7!jt622j-}!A+PFhJM%nzxSQT$}_CM~{SaN?o)Q3m z>Q9;R;s#wcz=39&Bck;$&p2fOKCdx0JvlYR2=6*CS!n7*{~-OAw+Y2nphAGVJ5MXs z^lw)hstdsCsUdbRef;+z{{PqaAhKsaJDLo>`+HRK;OI0^C+$LV|zopLS8K&MoDAIne zFk-uYI3isAxHE_NLnA|$>Ph4M7BeyD+wM9CD%P@IYeqVjFAGmkTc@w4C^(@`e53NK z=ejDH<_^}o>~;r#4;C+9*U96J%pQMdX5q*4)a-0Ijl0{-|*Crh10I?+WCYD34$h0Wn0AWkGf-Bg*(gK4a-t?E`p&%Q}bAaGR+#T z>1Eokw;o`^(A9y_; z?_lASFi_TVNI=YCia;lm`=MFgU5!DbwJOU#gn~9WvNU4ag>)P3c#3tGEydrA0 z%Ofw#U+RO1GFOK43^etLlXe~2W4-;RIM_(YP_KvKl4Ib0&`s4?m!O+8`u4L8c|mE7 zd5!zKPkPd?RTQ3Fw=Xs+$Tyn3^Ud}SM;F^IeNB#en}4_HP5S;CG-UDr_4wT3+RM|k zW2UjWvawnN+whGVhl#U!(T_DQcENkUeYtV7WSTTqhT&3f52`J9Z;%~6SBp46y&PcGcZf9*IQ-|b`jQT0QPV7Q^* zQIn1B;5dWWGcWNs#^ts#-BFY5rvagJA94q7+o}bXWe@XGb6obhP?1^irP)i|_`Tn& z_>2nW(qm6^vY#K`+G#(YGG~A=m7barHAv>I57wD}Eh@D>4{g_0*j{~;SD-i0IQZf! zA;ZU9#S^2bYq@3XzfcW&G%N!Rt{0@~W6}`9!K1Z?ih7xPdEy^Q>W%|WQyvSZT`Lqe zdyy9O!{l3f{()CSW};SNo_1dH7u%nm?05I_9^90$)X$W785fYmovggEtl-i3_03I9 zpS8>RzhjS0Wv|;--d%0Fki}VnNj3xBT{8#D%Hl1n&{2K4C2>9ghAc4I+oZES3Jp(@ z2?}sBdk^qji>sQSERiqf`@T5v(+dd{btWX^n~md^#OKl{g>-lCITiSfRBY))Zh_9R z_-?0Pg12_wG-mSU2@Sfhj1?(?hP};VcP0XiD-JsDUoUs0VHsTOxc}L2hgtT)X3Ars zq-*vi7F-u;`O^A-FYAKZ*tvmJJ~8hDGUIaZECv59G*OI-tI0>Ua_gma(EgaW$af{F zdHqLftL)vLIydvLeQy1u-WB&gGCJKZi0!p#(2uHd=8wEef-h^5BC(+k@BIG_r>$)LSPbl{a79ySOx#|8AG-xOpMs z;oY5R4sO%>dt%QFbU)-t7AfPFrQEXVKc#BvXDbT6>OVn!yONd^BxtG)u6f*`Z|vyz z=BG;6w(LI?FI6*3paF7zVLO^4MIlpMo_>U)b!9!*-5E52`}9~e<*)m3h$S7IH<_~a zGqutQYJuAI@tuSQvhhn&ZfA!)$&@pU?X|H5WFMmrY`E`yeQgq53g~+kcchj}^S@$q zOTa!DTj&qe;`XZlTH)SyPRI21p>CQBBB>?yzyH{UbDok`qlJ+r| z(GW$&gELW`RMeZ}{okxCa-O9-9OAb=5^hqsy@dzoOz`-woxVV!+TTz(-RvI!$yQ*t zd0gTBftz8)pDGJ=G39+3xixxUB|A$o-$DDRhG~TA{Mm#AN5>rrBNZx$UpSBDuAbA+ z_FW#%Z1$fo%<~4s{3fnEJnJ9ZnrPUR;OPF6`QCZ1D4g$QZ3Xn>}4-xxJvC{9w*R^qj+{?kdSb6Cn-&U{41*i+&M&`C$bsgBi++AK{ zUy}EW6>3x}*-zH{HKLaCTIq&A;f zE9_5?%Q~67sMlNjc1v)@$|*WY&iJOR?~?R~G^vtvCN=6-f?r;{D_z|?eGy&ZIvK<= z@%;V0;=Rm!tdb@_#wP-gT>XSC6cXq*^UP{Bqxzo;axL_ehgJm6Q(R2pfAyRH>`Yf1 z#6WN5%sYCvmcJFb?$ zMcYALRMw4$h&RZI#%b#RXw%6l73}$_dJiY zePW75+p9+I3c)O?+aG3x0{^eC_W-B*4gbd@iX%~xB&(7TDxvJCRQ8DM zy;t_0k;+cV2$8+_UM+-!>~oHTlIGS>8@Bh30b>(thy~p#upZj^9d%W(~ z&3u>cEQ7XVl=RERAy@IX#ZjLW9zebj+4o$RUQBlMoLGB!CEaj4yu@V6hnD!=?U0rUjV?r&C^LZS} zyjAgXqHWHQnfPJesVnC_k(!)j!(vVnVwU}P#tDX#P8=`zGHG(u^>iNaE^J6{$|j`_ z*j;BR(ERekVR`n!GiKf-Ftb%g+7k5I34=TJKRr5PmGh;fX}<}`C!IgsNRS6QitqFT zHUovfc3(yAMFw3$=XpuSv<|OZ?Ov9wEFL%R=`pbk+ms)!;w*W||B1e8**7`LSXSrWd6e0(C;s3qM2+X1*=(o78bXyhC<rWhdPaeGfF&;?^gvhV2pRm2VuBKeB{gpb3k|(^GjSTvAJ|Jqx%R-BZ zqWOSV3_J>c{PEjy))i+HY=mm5&h$><4t^JL{U%6@aw&+vz=oKN;aE}69y)yUs0p`f z`ZI(N88tn@#r!N|LjNoE=t=vb3SLHn_KSwyq@7M(?j98TP29j61LL#3j0q<3&#EK)tiKvkB!B2xV=STf-KV-zH$kaZcFz z<-&M#J9MpaE~bpKxz&PKu2YJBt>4r`8e<1o`_6}RS3N>y^{P4hRyu)=|H^Whc6iHU zM=kN!VAJNb9(vDhx;ky?Txs9Mbfea!R2$Bj?12=iR{#0d(Lg3_|99lhWxPEF-MzwR z(Z-Sywm)0KAQSqyT>f$rsWj{%o9+z?Vy||s7UjpBbMWHfyLUOisLrH@QvTpPc2D_u z>eZK5*h(*6Z9ioqe>QIFyp+!A4ylYD{iDZT$()zEdoOxRZ?xaLbN70pqezk?{N0}K z@~N-Wb-?R-=C5lKhHD!O%Lhiv95cNyTLfr*r@v0z^{p#p5D`H(3)`ZxmheZ=g;s_XEGrAqL6@$uVsE`$OLUOw*jzCjh* zlKic%>$OvzhEsvoXy%2buj4LD-AY<(%ZC!K61)d@M3WoLDqj}P?KJp&y(Snf=R3Ij zt6_UU@(Q8kGbM_c`OVWpCPY6?$#%o~iQ0|tPPV^nS&Vyq6qfsq<_`1aw|YD4Y^;yJ z&K(wWk5A7u$uDg=TGkPFwI(uu)GJpALPF;d;W25}G z;b@_z!ER0F8wymj*l4e9t_Fs?{mt;=?ez}@q7=OKKaN+{>ZR^AMsCY5+x0vMntXjj}AB5{;%i@oh zupxQV3y+3k!zq)UJZWdIyZv}Nk*kW@_U`%~RPil@k;gb!>laZ%Php|A>qgOp@7w1T zZ!luZRRlNx50xj&t*S+-*KV)+>Nvl`3A=R%S)(7z=@Er9mQc?1pd7I3Bcc%=N$syL z&G^twyQ_ESd!tMA;&1_y0!iy>5loU5X$;E5#AGaf2Z)o7FR8Rc)M{;C;6+TGe!+;g@XyYr=;rQ$~X3c+CT>^y$qbjEl1_Zga|0 z8Dj+#(fn9xN;KXBg`3G3<0>-K8II&$O7Gj4RVp-DCy}a1&lQ&|Vp>aHFxYg*<%v z6{n@&SkaW^fG*_Tbsnst;+h_6^vVu%b3DX~(em~EjG_u2Ot!~)W_Y8VrXQc(+-UjX zI$|%Acd+Y6-Gf{sym9YN@QQlALwip7E!D0JVrQJfP@|kxLbZ7m=C#daJA>D4LGC8?I3t#tF-vhLyZpbIs1y=8Je(_xUwPo%FzNRt~ZjI$btn(mSF% zMx1`*gl@s|_2hM$X3>WYq4mSHiG-4=aIEL>aV6ZBj5uo{V&=o$ zqqHR!RD3v-jP%6#FBwMNt=R>g_EHsr3xGk=8!(%`&LC_P&XLx8<=3kzIc62N}mp<@GA1I8avYsXx93^VPrmNkLpAf#@J^9H}+z zj6mzoc8@GHmcGP9ZWqhz3||#qx7JH~k6ij1DvXukL#fHePvaAQV&2ONEUs<5D7VVL zS2)76ar6wL>zb#k=g~uw0zn=AR|==hI*SZ!D^Q`N6ml{d2fMOO%PD92z7E);t!d-= zU70U@J$(OKd#N$DMfKOuS?V=pV@OpDPLzO z>_$X}O^I5JM?`X4_-Kb6*xPgbrbxD^OWX@oUSxkjR5ze1M{#j3_cm8=zRpQcCr~jt z@O>_`=EA42aRrB;DHR!b-8|!8jcAtGYfsE_*ekS;cL{b2E zW2nf(S=i-!;6%sHx^#>oGa|Xvq|-W%&$Nz(#JlXS9)0J_Wry}xqMjY0D-@jPJHLo5 zrpz2G37Twv@O!rW7)8-?2Q73!xV@Tk_g6PAI$XG7v4DH-$; zkoiUy_`olcJFku=954dZCAV)0Dw2~(9*4W0g@sW6HZqn94zMP;62jxpNwhYfl3>|{ zTT>e&EDTl=aD&}w&s8u);Qim>T}JvWV>&R1MLZ+FG|NMHr+)f%D*x6KEXE;qVkvj_ z#qq!I{GVam3WP7Y<#h0s1j~82D{`ocY|p+qK(vVP5*-&JjQ9C03i!pJ8w0Xb-o^n$ zA7Mb9I5*kSG80;stJ`E0@0_ffZ{y0Jk0o!oIYz%Or%p&s@$l&dWHCOES5xNPu|EdW zuxv%fmV&qx5fLw9WBHekcJ7l09##c(wIM?1s9*SvHrI`!8G7E1o$Vqs7dcVlo>7~1 zasSM`N4GRC(ru zn@GGNE%U5kFnO7xMz=S+}_ zSU%yUBv)>-Ga9kOQ97J@6fX_)C4#nBW%w;ne?=I_3#J^}K~`!DVjnLbccO?grMuoA zxc7(j#B)p0XBK8Qy+GVmdS9(R@7*z{|1$ahTVx+UGd!oZ|CX>Jf9cOJZi60a zoZ=w)Gi_w&C)l>3mbH}6tn=ZMBlLB>W0wb4yTo_PoMw80e({X(?f|Q5Njp#Nk=*kx zCdx^p=SPJ4d8SB3`lqYyImtI{Itcmo+1K)_Y!Kuy8tQA6H}^-QOd{!hFZsHCDgZG$ z;mL*o3{WVZyf_3n$eMmgcF%v)9W2n%d3xP#Y0MGgNxr@qAWx7>5Wq3xI$v@`=WCQU z+t#lQ(C{sz!!*mx4Ip1u3TzGN#@+2jn%o_zkj6m`cIh5K)T2wkOQ!fayTd!h|Ugq#~*}3RZIVRb{T<{$BlKP;ZLPXQlh=<>_gEMZF%tQ<_nUEcpS&Hrs9Pa7JNM zbzS###_NxofK;7+z*BsE*fAr{0&(1k8SvGmybVA{?rl#cS@Eg%2k@vD0qp|@ik^*u zO#Z0V?q}{jo{LII3!2nlJ57IFYEgO`#2KPiFUm8t1lkpsv%eHAhB2u5Nhn~V4)o%l_q1@ zeJXLF#6*Tp705W+0fRj6BWBTazV`cbt~@~23jk`LF!q>gzJajdxLaTn2r#Z|$Wy&} zMNnK_b$M~2c2OY66F|9TTDH}nW4N_U)%Coc1H6|%Q1A{tP}ea5Fow|7nSMu@C6^zq z$FzzKgZc|}F#V;#S-wyQ&{;H4XPOsFRe-t-Uptr!1+x{l*cxzGfL8)PU)qI+8UyUS z5Y`dcH8SX3c^kJ%XnC}UfW_QyFS#6bPFgctIsHwvZO}OIrh3)mj(FSj9^@$6zk1ilW`>1MN7b}3b0}RKY>^`r8REHW`IbO;;e_j9eFuopnoVB429Xh$EFQoYmII-4_43

SDaq3$H#9e41_LUdv?x~1FiJN{@jfZ>#`EnqRg>pejXgx&GgiZT<(H#d3# z(_|!wb&AG%F4ixYZB_zjeLOKY>2*Xn&7GOkr0(R5fL(F<7PG)eP^_nU!ZB8hM_cLC zJBRDaDHmETK& zYI>hoec>|SWXWr5*~FKF(mACKz5#PU;y2Ld3F52e+kmEaS^yd)%AD=lmbrr3&Nz;@ zD%A*~elt!j-P1CNU*Cm&0Fmd|+3PT{(WY`Li{9?^oTgB{R|}Kpr4jn@Z<|5X zZ7U1~HEF8Vb$*#IudP3$+2890m7D3;$RglabkpXb-QT;P(%zkB)8X@gg(s{ItTc|b zZ&~pv2X1}qJ)t0C0Idd`-VcWGHcZQZmE|J6%PEx-VVc9?q~d1}!>FY<2keQfXFR`o zH_L^zuN@a(^#l11-3Qfd<B+d8fD3i#vX$xtuZIJ%^)lAcrY~E`W3I?ekE(HbIn;mUt5{&=mI+OmohFYbR z9Cn*Cy}LpR?Kb@5^4~I>V~MeVh#v`y(?=I7%CcsE5MK)kfPZfSbPk;+h;S>&XNe%Y zRh!Ta-kyV>f|)#80a_#5dUA5qg}mM`dC^rZHuv^2S)g2#UCN0rx1lG;NizdrE;!#h z%Iupv$|@`=k#eO~P)hY9Cxgv?lT2d%oJ(tl{w@Ewn6O#J$aie62m_iDj;tPs8Zh_B zmz^FCs&2J?ygP^6c;PXKi z^wg_X`n~^1C2kEx5lO=0;Iq2tnm*@f?Maz7lcAVq068WE_KsDu zv#h~yfi2}9CCXdWO z1X?3H0M#@lnG8XN7G~dJQh|8Z^f5WUq}B4yA4gC$hb)>Cj|N#IYRa_vZCf);{NhtG z$K)WWPyu?1GLDa8Xa9Avsi&k_qIDok>0D)?iZlcf@Y|0mQHHNGBjqa4;_}iU(jzyx zYy0N87;_S0P?puXvy8H}{I!>0gZPD(QLsfE?H{f2HI$x@lp_pT;0FPljFlB3jf^MH z(C|fE8h>K_zHAqU=Ttg0P^f1I2o~qk?;~C-$Ve%HpcwL3ndu)i)fi%{GsV7k2ufBg z?R=NH_~`oRV&2lg^PXd5$p=_5qZK0K-x!Quo)H6UcnPI^b>9X(#z8$R8#~0BU`omH zu=}_I@=VVq0n!s_!AmbA`*<?ntEbUpqq2FirEq>b*e7Ps||_|LM_f1(UC ziijARvdGYpoXGA(n3Ve_EujKK#$n!(1Ki92{`)y+99B>~Ec9)OYW6h6A3ye$D>c|^Ai&WEl z7UDO!MMd>Mv63|xPRD-1!yZ$E(m*fDiuYP`aSN*sWpv*uT`5abW{qu_&z6#*HF>+; zSk|#7X<9N9lD_#p9j~I|U2Yb46qlQ^q-@alOHu&Iq}u%s^}wIrR&z5ug2UXx%Ilk< z7fwy<@l2OrHkvmW0Dv_fvv;!?HDBcNrXMop<2;~xD)+ML%|tk84X*Yjc@@77In?VD z#us!E9U&KNs&2_q6d)BAjIpgnmKc9IUiwL}L0ZxC*rCqej=I3M-ct{XgiJe5t$#a$ z>&g)KH)-Nnb_?Fz%D4q4M1-PV)0(H0?=>c38MMNi>)(>haE?__@j%z+zyc0KO8Bq4 zR?hO1h1SV@CTokv<@{QAiAvDRc>YX!t67Gyc5op&kaYy;OM*R2`(3zfUhQUL zXYWtrAYWK%5whkYZ4ck4m@;#o4=3~1&1s>T-f#YzTwP7a zFu0DFVez;By6K=u$od*+EjDW%|9s!T1ef)NeAhW@?mxa|%vTxV0NN!??7EicW5&a8 zrsN||A*V&sRZ=X_ z8^LN0_Oc^7K>U^E(frVry{!i-IC)OTpIEkmz{CbmYQTT@>_D!`O&=Utqrlz4s#)D1nDE6q0TC z+QaNsRP4d!+Y!&4X$U8`*PC7(Dr|x@cV(cDw_7( zz&-B(oDjB*RDR38B5-1ExqQxMBqlq5Q8D!vxmKf3J;5ESZu9gSeFkzKaDDo45loId z+Ryi;{9BRXM;*%}08Xa_0nD@le97~p4qF?` zEky=3kQ7dlXFzucL3)B`RVDkoQAKg=&N5lwZODMe& ztk2CBQUn@~KBM^h(;e?u7(#n?lcic+40PY@$0C1EGW@s|CIRl4Y^kabu_HO!&Q*BWW1^U~OR z0XSe?vEa~>qj(-hQczbt@`t`UDDaQ=xV^}A^R`c_=RN%Pp(^lb(kMkGyDsZy=fEEraJR>Sa(JbEJ45@)|&X=oC6lq_(7pNU7p-8-ER=1ZzLMwXgmq!vb zOk-Ac7kV}&XwGR`LaOjN?%x%#N}i`(-NWQDo#M7`=K02$nA zncwRy)`ptCeT*QK#joD13;Bw*0HCrH}vJ5$g+&6Rfome@5?yD!d5FOsxhynHONM7){0 z2sX|zDezVn11kZP0jve^j`^DX+&ZLir}bnVt=ZXsmaY=_Ta-1_s8j%GTnlLJqVw}% zU<_=etS~TGX!+(}Q}MvmsRM;;*C5?`u;2ik>d}FD_HJUPJ8b~ckYi=6liYmxo<9jHT~XVr4hKcsYsRg3!W zoBi0tf2y)^Qyz&q3m)7OpH81*ZsKtf7|m}wYLyntMFK} z)T(wc`7{vqoYfB$4;?MX%%@RoczcP?-n13Q&)RFx`tLh*h+$I8px7i<%&*D+0PzUj z5YU+M!$BorJ3PJJ?Dm#4(3)xbe9qH)E5FfqRpT9o9$}s-lAphIa$S>(N3?DjjYVTqxi29~V=mY#;4SECaQGvdD%O5pL z0x5^ZbC&XD4d;3@sR)AUPXz>TvTP9pPiQ4@2MJxb@FKl_m9I=?wtMCSm$r zgz|Tn;e$jZA1*DyCquxY-x}C04H`~-mhp=Wv;SxH!XZtv7w7QAqs(y83Su|+KM`uG z@DTb$Lybe~|Ia5j0SQUINqfv**pjB2S@7lo|0Tp8cbgjUQWJn+;e+S7`>rsRM@arz zDhG(zi!6FCSoNemvl*^Bc8rAjDX~AqFdxDuhfA<)cor~l%jX2FrpAb@e^nszEb-o9 z9(Lx&i^7nMKY~S@DeY^{hFvpb>=(%1{3D#L3zs3RD-7KlMYmUbr>=R!8KLG8{1JEH zz$t@!w{EIt$}>X56(fEZs7*|CYMhFp?r|e51b>!#nb;!(s+RR(Y$BNjux00Cod}LC2fb5w;6-Vd>-aS2_!+7Psf|9~NAIcI(vbxp{Qo{HGh~M#k9~~KODsHHF z{Ecri@Y&X{_Q>GjaAjAaym*L+g zjrmAj^+apb$IaE56DrVNkV+HEI&nZeAe`B=`8|txk1XJhsgZ~#m!v?1%bEUv4XYPt z-|{}&D_*yyM?p{2`R6^8)3ltbm%XDdKk7&j?Ai;dVt~b2*|wz|_6t*gAm~NeRAMJh zOF>JJSRBx6yMVQU-aDZj*TFi~3^~a?59Gf`9K*KMMOYJdH{YDxx8na(Q+en-dPuEE z|0XQE1i~Va1jn1^f7ZQ|(~Wz^p7dzS0ro#*3iD!Px0ox0q;gFTkgc`)5 zN{$;v90&brsJpHEB!dvuX4iqU91qc|Avr;;F%U?t=x z<|51mrWa8p=`{%X3y;4-7#{D@D_F~L8gAWih9dp}TfG=ctieXYP5NXMm@F#|~ ze?^@DJ$h^O>mi-NWD!=GbFXzWeQ1hU3!mCuM?0w(96x9a27{8zW&qMvpNUr5h{A}l~ z1U`$;m$*Hl@;P6-+5r#2X15GRRT?6aos-?_kAcQ$uHjldI2yuG#|BkZT~M5C02SB4 z9I~+T4Fb|1{v(4XF5-C93tl}jOoMkVCfWvi^mVu`-=Bh&^Zi(N&*EN$xoj~tpd6Am z)rI6%qo2V!F+6zz0z{YJD_{EOKp125hm;c;yN_NLCn%{dQ zH+F*xl6G~hZ>L^IA{o4IbN^)wF1>2{uE{??Vk_#^Z+beUscpwQzemf_``~-Wg@gk9 zZK<`b<0ZdRb1lyQCHIlzV-pyy?b|ld(vW&{Kf^)l-$%`2EJNwM@Q)!p1(|YT-ec^y zQowGYap9Ev83rkKyPEhXe_4=Y+@6U^`r`!pI7WhrJ7eV-?VaD?GzDNFB9)6Xe$#wmDeFE@Yp`ZsieQB z`w<`rDpro$OVaG+_6;(7&W}PcETpZQM4s=uIn@>B*!j zf0|Vj9?AW)W@U-0Wxn>D^d)tCE^Zb`3-SUWvs1o@#u$)|yR`k8E_2BHqxU;=pGk7m z2T_lU%zyiRlErt+j#^#dt4qqC)D#lWC^sf9rVimYRyye0AAUbtTfcw)mT+y7$vyZin%TnJyUWJYA2|kZ|jP6N#IVtF81DIsUtW)_+|W?4|;%aY(4x&cvWz zzFWRf#Z(hfEvtfUu2yvkQoE|#Wf5x@^mD2*V$|y+4`uIsd(&bF7f+Xdm{ng43SY@fgu;-K8d$&myvopv33YzVk8U8C|w~SgdUjLA#=R*%t&~a*M8ug4=v*|@YX}p3xq!9n3)3eB0BSF;ZX+Emx z{bneWksgjP@Ej_koSEF3NZpF1N+nt==EoK0Q!8Vj&+1#Wiq)poL=gXR{A9Od`6wSjcj33K}X<1rTf(D0rzRe+Qo2T>yLO`h0VG; zzIA&-=kh5z89DGu#?_7V@iX;o_lE+$p2K!@@%e6>Oq}K~?y9gY|KhDzByz)1b6s6} zNNRo9DA~KAUu?9z-BK-2KmJ&bnZ_>GdFYYrj!h6kV^NJ;rz&yiVZ(H4$fcJ0iQ&=) zho9%UoaeZY{>Y}&b?SPu;Mu}m`9{!qJWe6e{#I*@(8BOP(HXB(B^=j0eJ40ErEV$p zEAJ6=YiO&53D31nN*?3h<5z-IemP8(gwkIf0-C9O4WxAb(i-V-YOz^K`SePFocFq@)@6^E4=hk^q@IV>3yPfR zbY4!&R#dl(7kMNK)h6}i#WnQ|*erGw%kWlzDk$7~ONfhk&3U)G*!Pz*U54m-+hw{o zla5E#`P!A8Nm-jkw2jr>H*lv=?`lmN#yevtQlDWLYd=co8LpgMVEXy>$T>#GV5-tK z*VoVDr=CS}U#%%wX_vzF&)A^>eU~q{-D1eu@4H3fX|Q@~Zn!FnVVxgS?kY|}GPFsc z=n}C(TJ@JPC#R*mSpYabViIbY%4$tC35cj;b{ z!5H$=h1Ifa8LsC(no8t}Rn)&?HpErofc3~7s0qWoTX@Ced)WW$BRB`{@D42ge6wvf zQxZFMUXd&4M>*Q{k%8+Pz&}{E!hWu!LSG-k1TzStx8)LQC(;n#POnYRn1;=|$NgOJ z95CxM_bguHbCHfZOCD~5y!+s8(vWg@Q<61#;zzz@52L4%+-vjIe1C6;=g(4yv10mm zw)a4^xGyClJ#b;63_P{g6}RBu_wgjJ9}9?4NKxD zyjdBiK#gb>_gK14Pn-JkXQ}m*62(90E%M;6yv0{6iYlmWOl7>`(&4jZoUipF8)a#d zZ2dzrXkJ~{!pN-AB%~$$nTFWutk5>qrOY@t|8vx2=TDZfxirQ&&sOQDN{37?M8pig z^2lr8_epyaHv3i7GKmd5-RZW4k9y^=GaGEZlWa|N;;)5g_pO|hM!kI&J$&ktX3<*; z6wOc0`F?9YoJCe#DkFoywM2&X?v6f9TCNM%V}wOeO$hZ|`%Y_>SN2;i(fDmDX>QL7 zYy5bZIJ;1En3>pGGc~K$EXvXBZ?kI2pta%=@^nI@M|naDq&U%!xi3d^+kL*jFghro zJPCQHMWDzlHWZ{QA)FR`MvfZ!HFhu;Q(THkws{onVylrrT)3A^t#xwd8~o0MAa2lU z=nz;OqUT%TUI?Zl_K!6kuox|)r#SiW9aFvb+ea!5bptMA`ht(8Qxfogn@)3TdzkfH zRKa7+M>#$`VzPb8bei2xU zT~f+SV!+rPW%_~F0YLf{@pL^s!ore4`>q6SrITu@@LKT+Q7&%FK?i2$8*A+rguw^u zsD{UfPn>E0SmRWoz5dFzs{v9}W?4(}U4&XLwW{BkckZu|d-6)61PiBhs`GCPXY6An zF8Bm{-uBpQ?RKhyz9cgpL&=I5-a7MAI$nFi4QwoXM8aA|x*3CMSDB}gO;y2b!lq7* zX0$X@F{erU!;cH1hLXZTj{c%3w1C@LQr*bY|aqJX>A9{!J}Dn;h=VrFXHSsMV;bU*IjUXv^GMvp!j<(^AkGgWRmNu;HWQ zcHp~fc9b5PIGD&gU80J++Rx23m~z0b{dKLjXql?VPF9rdjS9T6iedp(*e9_&>+=Tg zTw-H8&&pLoqLv zA&T-5a&F_k(kqQ{`VZk8r0Bohy}F1fX%n<=b7j2wxey{J!F<5KGJ zAg8ME_@QunyLuZ0dX>LtlHcmWsgmuVbAZivT)Z8<*lA$t{*$T<@t7GQzvmk=b+ktg#H6o9#dGi)R$KR9an9yuD{vErh-1 zV<$SJK60y2oB?D`(O2oiPUa|$@Z7?Rklys#ZY}NZUo8qrWsjx9Yfx749CU>4E@)#N zbQE35?V>|_+R|Nn_?)fq_Z^X9f{x;6QMaqCT0flQP0g7^Jo+4Ty~}l8@mtA~9%RK7 zO$ljzobAUAJv$M@=ZW{dCZTr_8FK^B&#}t^r`IuyjkQ!yjoaL+C)$#*J;OJrw#2aEZ@VdW!(& z859;1@b~e`!U4X_y|3e%h2W34XHJ#9aa#B}q1MnhN;^_E8Nfv=DJ36%f#mn3fO`9H>$ zirJSbU{v(6-;a^K1841&J;&*PoE1oxQeeZm>bWFZC$SChCc#_!*L7t%3tsiv60I!N z*gg09h%A7Ll;G<_t7!g7Mx{{X`Q&6L?tBaGrWm_VB69H&>Xo`C>dFa`hb!GV9ga*O z%-b1$Nu$@mml9=&>wl;A_v#x{Cqf$Mt`2B73M9J_RGaL^0sgx-Q-qf*|>SofJDA9dN(EQ?w88m4{bc<*FFvxZi~E;t+RGBM zkq}>-wL$;wUIvr6{lY7nDfJtlSi)36syEu^t)k+8Lge`07n4wrO(a|ylLm~NC*YYQ zjL-@;5G+u!>s9;fB9Ze=?UlhoAvqVIeS35LC_2IVU{JtWWsuw1+MxS?E7bws(NH3B z#D>i7YMes9a&>cf=%ZCg{3gY;55Q**vt_#-`xM9SO$QNF*W*VHa()D!`z z+;tzTLgtpTWVCGSWadal{$Cfl-SvBIN9BC-ftXl{4@II|=R4Es9%{aG#zv?9FJ(Tz`g42Ou?qzu zL+wOp|7QmdQCRDy3J37DvKSMz!lp!Ce*MUw%?+3tv}n`zKWk9eKk&m7@tuv2tJ2aC z9_5BWFxgpnWuUdMD_fU)-w_Kv{=i72I#$l;ultt5C9(`YLT|f0ie8s>tSbg^)i%(P zX)ix5lOT{^4@l5G(%$xD@kUpctn5M(Nyc^E+Xz6GOzGm5rwasGwVc*@^Z&|`G?fyk zG+ZuTGKzdQN;{eJo6PDxhn8^|?@JQB+ZQ9t z9~p3=Y76b+JO3xdm2m^Yxf*(9F?ar2P5-;2Q~y6>kNf}c=L>nA;$EjrLxrl_;c^J{H+2>{eb@iQJN^C literal 0 HcmV?d00001 From 2f5ba6937caa0226e4cfd6a6cb1f91e3c19aa529 Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Tue, 10 May 2022 07:41:37 +0000 Subject: [PATCH 2/3] unify python style --- docs/design/phi/design.md | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/docs/design/phi/design.md b/docs/design/phi/design.md index efb363f8ce2..2e35d07a78d 100644 --- a/docs/design/phi/design.md +++ b/docs/design/phi/design.md @@ -1,5 +1,5 @@ -飞桨高可复用算子库 PHI (Paddle HIgh reusability operator library),或者我们也成为函数式算子库,支持组合式算子功能复用、Primitive算子内核复用、插件式硬件加速库复用。针对飞桨框架原算子库存在的算子接口不清晰、算子复用成本较高、调用性能不够快的问题,我们重构了飞桨框架的算子库,设计了灵活、高效的函数式算子库 Phi,可以通过对函数式算子接口组合调用的方式实现新算子。新算子库提供了 200 余个跟 python 开发接口保持一致的 C++ 运算类 API,以及近500个可供组合调用的前、反向函数式算子内核 Kernel,可大幅降低框架原生算子和自定义算子的开发成本。新算子库支持Primitive API方式开发算子内核,可支持不同硬件(比如GPU和XPU)的算子内核复用。新算子库支持以插件方式接入硬件(比如NPU)的加速库,实现低成本复用硬件加速库。 +飞桨高可复用算子库 PHI (Paddle HIgh reusability operator library),或者我们也成为函数式算子库,支持组合式算子功能复用、Primitive算子内核复用、插件式硬件加速库复用。针对飞桨框架原算子库存在的算子接口不清晰、算子复用成本较高、调用性能不够快的问题,我们重构了飞桨框架的算子库,设计了灵活、高效的函数式算子库 Phi,可以通过对函数式算子接口组合调用的方式实现新算子。新算子库提供了 200 余个跟 Python 开发接口保持一致的 C++ 运算类 API,以及近500个可供组合调用的前、反向函数式算子内核 Kernel,可大幅降低框架原生算子和自定义算子的开发成本。新算子库支持Primitive API方式开发算子内核,可支持不同硬件(比如GPU和XPU)的算子内核复用。新算子库支持以插件方式接入硬件(比如NPU)的加速库,实现低成本复用硬件加速库。 > 本文档撰写于phi架构基本成型之时(2022年2月),仅代表该时间点的基本设计形态,可能和最新形态有细微差别;此外,在2.3版本发布的phi算子库仍然处于初期形态,后续仍然需要持续建设并完善,设计上也有可能调整。 @@ -64,9 +64,9 @@ Paddle 2.0发布之后,多次收到内外部用户反馈动态图在小模型C ### 1.1.6 Op及Kernel参数规范化 -python 2.0 API项目规范了Paddle Python端API的参数列表,使其变得简洁、易用,但是限于当时的情况,Op层面的参数列表并没有规范化,因此会有不少早期开发的算子和Python API参数相差较多,例如conv op这种,python API仅有7个参数,但C++ Op却有30+参数的分裂情况,而API和Op本质上是同一层概念,都是对一个运算的描述,参数应该是一致的。推理为了解决此问题,推动了算子定义增强项目,为部分不需要的参数添加了AsExtra以及AsQuant的声明,但并未从根本上解决问题,这也是phi算子库构建希望重点去解决的。 +Python 2.0 API项目规范了Paddle Python端API的参数列表,使其变得简洁、易用,但是限于当时的情况,Op层面的参数列表并没有规范化,因此会有不少早期开发的算子和Python API参数相差较多,例如conv op这种,Python API仅有7个参数,但C++ Op却有30+参数的分裂情况,而API和Op本质上是同一层概念,都是对一个运算的描述,参数应该是一致的。推理为了解决此问题,推动了算子定义增强项目,为部分不需要的参数添加了AsExtra以及AsQuant的声明,但并未从根本上解决问题,这也是phi算子库构建希望重点去解决的。 -我们希望能做到,Python API -> Op(C++ API) -> Kernel API三层参数一致,使整体架构清晰,每一层复用也清晰,一套python官方文档,基本能够满足三层API的共同参考需求,不再着重维护额外的文档体系,降低维护成本。 +我们希望能做到,Python API -> Op(C++ API) -> Kernel API三层参数一致,使整体架构清晰,每一层复用也清晰,一套Python官方文档,基本能够满足三层API的共同参考需求,不再着重维护额外的文档体系,降低维护成本。 ## 1.2 目标及范围 @@ -142,7 +142,7 @@ paddle/phi 部分目录结构说明: - `api`:API模块,面向外部用户 - - 直接使用类python的C++ Tensor计算 API,和Python端形式高度一致 + - 直接使用类Python的C++ Tensor计算 API,和Python端形式高度一致 - 该部分可能反向依赖框架的DeviceContextPool等实现,所以单独管理 - 在该类API上,训练和预测也可能是不同的 - `common`:phi内部及phi api目录均要使用的数据结构,这些数据结构既不属于phi core,也不属于api目录 @@ -529,8 +529,8 @@ Tensor scale(const Tensor& x, C++ API的自动生成是通过解析Yaml配置文件来进行生成的,Yaml配置文件分为: - - 前向API配置文件(`python/paddle/utils/code_gen/api.yaml`,解析后生成代码文件为`paddle/phi/api/include/api.h`和`paddle/phi/api/lib/api.cc`) - - 反向API配置文件(`python/paddle/utils/code_gen/backward.yaml`,解析后生成的代码文件为`paddle/phi/api/backward/backward_api.h`和`paddle/phi/api/lib/backward_api.cc`)。 + - 前向API配置文件(`Python/paddle/utils/code_gen/api.yaml`,解析后生成代码文件为`paddle/phi/api/include/api.h`和`paddle/phi/api/lib/api.cc`) + - 反向API配置文件(`Python/paddle/utils/code_gen/backward.yaml`,解析后生成的代码文件为`paddle/phi/api/backward/backward_api.h`和`paddle/phi/api/lib/backward_api.cc`)。 C++ API生成的关键在于Yaml文件的配置,以matmul为例,其前向和反向的配置文件如下: @@ -1590,7 +1590,7 @@ class KernelContext { 目前,phi算子库仍然处在Kernel体系的建设阶段,Kernel尚未完全迁移,且仍然存在诸多完善点,但将来phi算子库会更好地将“算子”的概念纳入进来,这还需要比较长的时间和比较大的人力投入。最后,从“产品”的角度介绍一下phi后续对于算子开发范式的规划,也能够让开发者更容易理解 “为什么要做算子库重构?” 这件事。 -### 2.5.2 原算子开发范式 +### 2.5.1 原算子开发范式 我们应该如何描述“框架算子”这个概念? @@ -1684,14 +1684,14 @@ REGISTER_OP_VERSION 说回开始的两段式结构,”1.算子描述;2.算子执行“,分这两段是必要的,也是业界普遍的做法,我们不需要再分第三段了,但paddle目前存在第三段,算子描述分了两段进行,并且这两段还不一致,即PythonaAPI和Op。 -API和Op都是对算子运行行为的概要描述,本质上只是同一段内容的不同展现形式,比如[python dot API](https://github.com/PaddlePaddle/Paddle/blob/c6f49f0b9f189e043b458348d7fd1468e2645621/python/paddle/tensor/linalg.py#L993)和[DotOpMaker](https://github.com/PaddlePaddle/Paddle/blob/c6f49f0b9f189e043b458348d7fd1468e2645621/paddle/fluid/operators/dot_op.cc#L35),就是告诉别人“它叫什么,参数都是什么”。 +API和Op都是对算子运行行为的概要描述,本质上只是同一段内容的不同展现形式,比如[Python dot API](https://github.com/PaddlePaddle/Paddle/blob/c6f49f0b9f189e043b458348d7fd1468e2645621/Python/paddle/tensor/linalg.py#L993)和[DotOpMaker](https://github.com/PaddlePaddle/Paddle/blob/c6f49f0b9f189e043b458348d7fd1468e2645621/paddle/fluid/operators/dot_op.cc#L35),就是告诉别人“它叫什么,参数都是什么”。 咱们对同一个东西的描述,分两个地方写,还写得不一样,这是很令人费解的。就好像你介绍一个人,在学校你说他叫”张三“,在公司你说他叫”张三丰“,有相像之处,但又不是一个意思。 对于一个算子,它的输入、输出应该在各个场景下都是一致的,如果不一致,那本质上就不是一个算子。 -比如,conv2d的api和op,[Python conv2d API](https://github.com/PaddlePaddle/Paddle/blob/c6f49f0b9f189e043b458348d7fd1468e2645621/python/paddle/nn/functional/conv.py#L416),很简单,8个输入参数;但是对应的[conv2d op](https://github.com/PaddlePaddle/Paddle/blob/c6f49f0b9f189e043b458348d7fd1468e2645621/paddle/fluid/operators/conv_op.cc#L259),有**32个**输入参数,让人摸不着头脑。 +比如,conv2d的api和op,[Python conv2d API](https://github.com/PaddlePaddle/Paddle/blob/c6f49f0b9f189e043b458348d7fd1468e2645621/Python/paddle/nn/functional/conv.py#L416),很简单,8个输入参数;但是对应的[conv2d op](https://github.com/PaddlePaddle/Paddle/blob/c6f49f0b9f189e043b458348d7fd1468e2645621/paddle/fluid/operators/conv_op.cc#L259),有**32个**输入参数,让人摸不着头脑。 开发者也会很困惑,我开发op的时候,API和Op不是一个东西吗,我应该写得一样呢?还是不一样? @@ -1704,7 +1704,7 @@ API和Op都是对算子运行行为的概要描述,本质上只是同一段内 对于这个问题的解决,我们的方向是很明确的,就是**Op层描述向API层靠拢,因为API层的定义是经过2.0 API项目仔细设计过的**。 -### 2.5.1 新算子开发范式:完形填空 + 拼积木 +### 2.5.2 新算子开发范式:完形填空 + 拼积木 phi期望的Op开发方式:**“完形填空”式算子描述实现 + “堆积木”式算子执行实现** From fd94966dd724080c305840de0ae454c1ce16a293 Mon Sep 17 00:00:00 2001 From: Chen Long <1300851984@qq.com> Date: Tue, 10 May 2022 22:45:44 +0800 Subject: [PATCH 3/3] Update design.md --- docs/design/phi/design.md | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/design/phi/design.md b/docs/design/phi/design.md index 2e35d07a78d..c67d07f20df 100644 --- a/docs/design/phi/design.md +++ b/docs/design/phi/design.md @@ -1,5 +1,5 @@ -飞桨高可复用算子库 PHI (Paddle HIgh reusability operator library),或者我们也成为函数式算子库,支持组合式算子功能复用、Primitive算子内核复用、插件式硬件加速库复用。针对飞桨框架原算子库存在的算子接口不清晰、算子复用成本较高、调用性能不够快的问题,我们重构了飞桨框架的算子库,设计了灵活、高效的函数式算子库 Phi,可以通过对函数式算子接口组合调用的方式实现新算子。新算子库提供了 200 余个跟 Python 开发接口保持一致的 C++ 运算类 API,以及近500个可供组合调用的前、反向函数式算子内核 Kernel,可大幅降低框架原生算子和自定义算子的开发成本。新算子库支持Primitive API方式开发算子内核,可支持不同硬件(比如GPU和XPU)的算子内核复用。新算子库支持以插件方式接入硬件(比如NPU)的加速库,实现低成本复用硬件加速库。 +飞桨高可复用算子库 PHI (Paddle HIgh reusability operator library),或者我们也称为函数式算子库,支持组合式算子功能复用、Primitive算子内核复用、插件式硬件加速库复用。针对飞桨框架原算子库存在的算子接口不清晰、算子复用成本较高、调用性能不够快的问题,我们重构了飞桨框架的算子库,设计了灵活、高效的函数式算子库 Phi,可以通过对函数式算子接口组合调用的方式实现新算子。新算子库提供了 200 余个跟 Python 开发接口保持一致的 C++ 运算类 API,以及近500个可供组合调用的前、反向函数式算子内核 Kernel,可大幅降低框架原生算子和自定义算子的开发成本。新算子库支持Primitive API方式开发算子内核,可支持不同硬件(比如GPU和XPU)的算子内核复用。新算子库支持以插件方式接入硬件(比如NPU)的加速库,实现低成本复用硬件加速库。 > 本文档撰写于phi架构基本成型之时(2022年2月),仅代表该时间点的基本设计形态,可能和最新形态有细微差别;此外,在2.3版本发布的phi算子库仍然处于初期形态,后续仍然需要持续建设并完善,设计上也有可能调整。 @@ -50,7 +50,7 @@ Paddle 2.0发布之后,多次收到内外部用户反馈动态图在小模型C ### 1.1.3 自定义算子的易用性提升需求 -2021年初上线的新自定义C++外部算子体系,在接口与函数编写的层面上,用法已经比较直观了,但是因为我们缺少基本运算的C++ API体系,事实上,在实现具体的自定义Op运算逻辑时,一些基础的加减乘除及矩阵运算都仍然需要重新实现一遍,不能复用Paddle已有的、经过优化的基础运算,因此一些复杂运算的外部开发成本仍然是比较高的。而要想复用Paddle内部的基础运算,有赖于的Op体系升级为函数式,并整理对应的C++ API体系才能解决。 +2021年初上线的新自定义C++外部算子体系,在接口与函数编写的层面上,用法已经比较直观了,但是因为我们缺少基本运算的C++ API体系,事实上,在实现具体的自定义Op运算逻辑时,一些基础的加减乘除及矩阵运算都仍然需要重新实现一遍,不能复用Paddle已有的、经过优化的基础运算,因此一些复杂运算的外部开发成本仍然是比较高的。而要想复用Paddle内部的基础运算,依赖于的Op体系升级为函数式,并整理对应的C++ API体系才能解决。 ### 1.1.4 共建训推一体算子库,降低推理算子维护成本 @@ -146,7 +146,7 @@ paddle/phi - 该部分可能反向依赖框架的DeviceContextPool等实现,所以单独管理 - 在该类API上,训练和预测也可能是不同的 - `common`:phi内部及phi api目录均要使用的数据结构,这些数据结构既不属于phi core,也不属于api目录 -- `core`:phi内部会有一些自己需要的,公用的模块实现,比如基础DenseTensor、,kernel注册及管理模块 +- `core`:phi内部会有一些自己需要的,公用的模块实现,比如基础DenseTensor、kernel注册及管理模块 - `backends`:backends中组织后续需要为各个后端的新增的数据结构,比如CPUContext、GPUContext等 - core中放置对于算子库来讲通用的基础数据结构,而特定后端的专用数据结构不放在core中,且依赖关系严格保证backends依赖core,但core不能依赖backends - 例1:Context如果有基类,则在core中,而继承的CPUContext在backends/cpu中,GPUContext在baackends/gpu中 @@ -626,9 +626,9 @@ void Scale(const Context& dev_ctx, - 基本流程如下图: - ![图片](http://bos.bj.bce-internal.sdns.baidu.com/agroup-bos-bj/bj-2aafdb051eaea7120bdf9604eb738029dcd3162a) - 这种方式存在的性能问题已经被torch自身认识到,所以torch也在做算子库重构,但是积重难返,他们重构也并未对此问题从根本上解决,只是减少了一些redispatch的层数,我们不能一味模仿竞品自身都认为有问题的设计 -- 为什么第一个参数需要是DeviceContext?为什么不能不传? +>- 为什么第一个参数需要是DeviceContext?为什么不能不传? - phi kernel要求是纯函数形式,即函数内使用的变量均通过参数传入,或者在函数内部创建,不允许在函数内部使用全局单例,为了适配多样的kernel需求,像DeviceContext这种存储上下文信息的参数是必要的 -- 为什么需要两个模板参数? +>- 为什么需要两个模板参数? - 为了方便设备无关kernel的复用,假如我们要实现一个傅里叶变换fft kernel,假设这个kernel能够使用基础kernel组合得出, #### 2.3.4.3 Kernel实现