Skip to content

Commit

Permalink
noDA performing very well
Browse files Browse the repository at this point in the history
  • Loading branch information
ShrihanSolo committed Aug 7, 2024
1 parent 7dc686f commit d5413a7
Show file tree
Hide file tree
Showing 25 changed files with 7,511 additions and 289 deletions.
9 changes: 5 additions & 4 deletions src/training/MVE/MVE_SL_DA_v7.ipynb

Large diffs are not rendered by default.

11 changes: 2 additions & 9 deletions src/training/MVE/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,23 +45,16 @@ MVE_SL_DA_v6:
batch_size: 256
comment: go crazy w/ batch size with LR expected to make gradient stabler. Didn't learn variance so well.

MVE_SL_DA_v7:
DA_weight: 1.4 -> 1.0
beta: 1.0 -> 0.0
lr: 3e-5
batch_size: 64
epochs: 250
comment: go crazy w/ batch size with LR expected to make gradient stabler.

MVE_SL_DA_v8:
MVE_SL_DA_v7:
DA_weight: 1.4 -> 1.0
beta: 1.0 -> 0.0
lr: 3e-5
batch_size: 64
epochs: 250
comment: Combine knowledge for v1 NN model.

MVE_SL_DA_v9:
MVE_SL_DA_v8:
DA_weight: 1.4 -> 1.0
beta: 1.0 -> 0.0
lr: 3e-5
Expand Down
560 changes: 284 additions & 276 deletions src/training/MVE/VisualizeModel.ipynb

Large diffs are not rendered by default.

3,404 changes: 3,404 additions & 0 deletions src/training/MVENoDA/MVE_SL_noDA_test.ipynb

Large diffs are not rendered by default.

3,021 changes: 3,021 additions & 0 deletions src/training/MVENoDA/MVE_SL_noDA_v1.ipynb

Large diffs are not rendered by default.

Binary file not shown.
Binary file added src/training/MVENoDA/models/mve_noDA_v1_aug7_0100
Binary file not shown.
1 change: 1 addition & 0 deletions src/training/MVENoDA/models/mve_noDA_v1_aug7_0100.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"train_DA_loss": [0.08701178500977033, 0.0869345870515982, 0.08686002830836848, 0.08739230633326815, 0.08779734099314461, 0.08779209613854531, 0.08817189092146849, 0.0896750866372472, 0.08929752625892759, 0.09110028105336744, 0.08917921372234712, 0.08901007318191598, 0.0896472274317275, 0.0887210353309973, 0.08924299177511083, 0.08943420952318791, 0.08997476404665593, 0.08944166615823287, 0.08908946325890107, 0.08879234903990896, 0.08784136199945723, 0.08823276270296718, 0.087440431921639, 0.08781708531776337, 0.08831107427095364, 0.08824166049414622, 0.08787977114537734, 0.08780785028225108, 0.08838255273031796, 0.08795078466535704, 0.08799688524176695, 0.08642718853119087, 0.0870791675430983, 0.08773556487869737, 0.0881936242242073, 0.08702008110186736, 0.08745251344670107, 0.08606364268816367, 0.08656250366726567, 0.08669087337755634, 0.08708776946995114, 0.08653846201257034, 0.08645077173681538, 0.08678987376405509, 0.08644906949027147, 0.08583993137808997, 0.0871460512008048, 0.0857194310813543], "train_regression_loss": [2.588466390831161, 1.1825510290684604, 0.5614199621476901, 0.3815815892254413, 0.3340159839587848, 0.30207329966686325, 0.2660155170197025, 0.22576487543813925, 0.1824270902266232, 0.1445064116303402, 0.11488425209301045, 0.09143752994278014, 0.08345764306062319, 0.07955241897288144, 0.0761071445396548, 0.07261762830685872, 0.06994671807019977, 0.06728829378415396, 0.06471530992326614, 0.061855291256466556, 0.058955402731541304, 0.05620396804997537, 0.05319004204852926, 0.05055610068360347, 0.04963042438915595, 0.04818254581889243, 0.04683273047603789, 0.04595185547958126, 0.04505998927629953, 0.04419220330523943, 0.043325786909575434, 0.04274775103497331, 0.04185574280074244, 0.04108721921237448, 0.04066016430365538, 0.04016399426109939, 0.03915787002980927, 0.03864424340290714, 0.037948356280733286, 0.03779997423604485, 0.037403679486888525, 0.03719191791547609, 0.036498903693593834, 0.036365437976866144, 0.035696833951218254, 0.03576802622043346, 0.03563109608588424, 0.03500147754159608], "train_mve_loss": [1.1354291536036532, 0.44813819077049794, 0.1424143358475956, 0.03656194686687039, -0.00367295232191137, -0.02738715086891648, -0.04407317304542761, -0.05784166720142448, -0.07007532899984767, -0.08195789591145362, -0.07859376340191913, -0.09454077622517808, -0.13180263860787944, -0.1257819156596604, -0.11937972268522112, -0.10642043872305637, -0.13469936284602668, -0.1127016317964254, -0.11196902978627293, -0.11663919079783194, -0.11648614307562438, -0.11970292086929246, -0.11384526539382812, -0.13455337917319818, -0.14930909159686928, -0.11708425542499727, -0.12581275244136178, -0.17171698705832963, -0.14640118695209842, -0.15098256967839202, -0.1771902636755973, -0.13755703565257668, -0.11721230330289609, -0.17824129418383133, -0.13562473717743234, -0.13889788371771303, -0.1528992286050723, -0.1383786405891289, -0.16732195420854706, -0.16286113076729247, -0.1685672777873452, -0.1910746400082787, -0.1425997059097619, -0.21060223062905795, -0.22930107335176939, -0.19832132716528678, -0.20366261088352317, -0.25116531986750024], "train_r2_score": [-6.94498025102126, -2.627822628215485, -0.7139025151671714, -0.16441155419503467, -0.018874912502087247, 0.08040305454380076, 0.18888731957441354, 0.31227805994925756, 0.44339716334037443, 0.5574774548477229, 0.6500614646575191, 0.721114207565642, 0.7454466455714918, 0.7571662829631667, 0.7676074148784228, 0.7780083007831876, 0.7863905975033761, 0.7948377041324136, 0.8019812691847247, 0.8108488621323561, 0.8200904839533054, 0.8279188165168125, 0.8371299834950049, 0.8456284247102861, 0.8483959889032768, 0.8525392631821539, 0.8571484143930086, 0.8593133557231842, 0.8619804935495772, 0.8650785579088476, 0.8671678059420826, 0.869375221463255, 0.8720044004200019, 0.8743780391638268, 0.8754256109245099, 0.8770603924910585, 0.8803473208015773, 0.8817677085047427, 0.8839573556253385, 0.8844636530545846, 0.8857871319206628, 0.8859511547226225, 0.8886492851324217, 0.8884645773229753, 0.8905729432964856, 0.890343802312937, 0.8908875182673991, 0.8930276461466566], "val_source_regression_loss": [1.803157246565517, 0.7873829912535751, 0.4435379014739507, 0.3630698413788518, 0.32702378098723256, 0.29177709131301205, 0.2510951825335056, 0.20779083122180986, 0.16620735409139079, 0.13263118908375124, 0.10412626402287543, 0.08865553018035768, 0.08346782233330267, 0.07960960811263398, 0.07692589973912964, 0.07247834511195557, 0.07121705494915383, 0.0676920822338213, 0.06504997566248043, 0.062015903382738935, 0.058744505425042746, 0.056024806149586846, 0.05418096013555798, 0.051917979517315006, 0.0498113411396176, 0.0488753083244532, 0.04759580910771708, 0.04707503828066814, 0.04647931164201302, 0.0454176637709518, 0.04489489187356792, 0.044488877010873604, 0.043391686239385906, 0.04282742165783538, 0.04166685615347911, 0.04080204943878741, 0.04098410885545272, 0.039759129922412616, 0.03915172171649299, 0.039375986668128, 0.03874065099826342, 0.03902055294830588, 0.03952017652837536, 0.037826039557215536, 0.038223929703235626, 0.038755217307730565, 0.03789755423800855, 0.03665409182917468], "val_target_regression_loss": [1.7730703353881836, 0.7884347348273555, 0.44815569927420795, 0.36494459438173077, 0.33760078930402104, 0.3035271595927733, 0.26389996122710313, 0.22296277197855938, 0.17647109378742265, 0.14242850536409812, 0.11348173173168037, 0.09442131174138829, 0.09183763726791248, 0.08809986399321616, 0.08341789255036583, 0.08099211521352394, 0.07880496134675002, 0.07581153394112104, 0.07139307666051237, 0.07086997153826907, 0.06646393894960609, 0.06310069414822361, 0.05919375930783115, 0.05793009431961971, 0.057256151179346855, 0.05545272017958798, 0.05431444801484482, 0.05357151775609089, 0.052760786247215696, 0.052807340162652955, 0.05055432990664923, 0.04997710518161707, 0.04962258271967308, 0.049235300620711304, 0.04772259771258016, 0.04724956504246102, 0.04610099120041992, 0.04643780297210699, 0.04544762889795666, 0.04515141161465192, 0.04565783828214, 0.04317924689171435, 0.044111377689280085, 0.043544781500402883, 0.043052586432121974, 0.04238546037268412, 0.04256317458009418, 0.04193540207475801], "val_source_r2_score": [-4.912551528043053, -1.408837998980508, -0.34716073840822254, -0.1015982908926856, -0.0013404807959346077, 0.11722713886493696, 0.22392170843968354, 0.3675215659623309, 0.4959231095415311, 0.5832460908450959, 0.6799533555871716, 0.7273613570336311, 0.7465383832364237, 0.7580629842561677, 0.7643088103197836, 0.7801696209789357, 0.7833148991093587, 0.7934569892728398, 0.8024240230059154, 0.8107093311931695, 0.820441565703191, 0.8298271252339029, 0.8360612006996805, 0.8423020404063329, 0.8479168394890242, 0.8511108658104813, 0.8556037802762935, 0.8557144648201636, 0.8561271615352465, 0.8630663688240349, 0.8635872164049986, 0.8613199993581209, 0.8683365458728936, 0.8605975432939288, 0.873281713996555, 0.8749475199149338, 0.8755783524390657, 0.8789385459174848, 0.8804477345366967, 0.8804890730432472, 0.8826849779486272, 0.8808572478494736, 0.8761550823802314, 0.8853355216925657, 0.8820575172219932, 0.8829615267631623, 0.8829217606678287, 0.8868764049373721], "val_target_r2_score": [-4.474305083848151, -1.4430559423125167, -0.38810705417478214, -0.1255622031191915, -0.035202511107907196, 0.05890933598062741, 0.18802581516675113, 0.3141923414671601, 0.45124622579437423, 0.5615591768511826, 0.650812652751277, 0.707930752015543, 0.7143929555905613, 0.7267999441871635, 0.7420594296975732, 0.7504118698535827, 0.7532294957722695, 0.7665710129622714, 0.7804640494318467, 0.782552658665, 0.7963848023926126, 0.8043696907881142, 0.8163968198728367, 0.8205820085048395, 0.8210821343459019, 0.8268064179533244, 0.8327806429621989, 0.8302605794049832, 0.8367949729699183, 0.8367841340993789, 0.8437012584728926, 0.8451009523017751, 0.8467475951367981, 0.8449402200969445, 0.8517483679338245, 0.8516068304326397, 0.8576120135181012, 0.8565132378315, 0.8582330401350724, 0.8599615897130621, 0.8589902354169309, 0.8664189747137506, 0.863617403230646, 0.8657367683237833, 0.8663568897915025, 0.8672086407342573, 0.8663686102253959, 0.8702505327110062], "val_source_mve_loss": [0.7419716708267792, 0.25748599594152427, 0.0762452724377943, 0.01772617343577403, -0.012789547475809349, -0.03307967349486072, -0.049519451547272594, -0.06093021076691302, -0.07282304660051683, -0.07506635733231713, -0.07528842787576627, -0.14220587050990213, -0.11870872342511068, -0.11795557026244417, -0.11080868193243124, -0.15312568003995508, -0.10911494908453542, -0.11706290431792223, -0.09765008280548869, -0.1476523731894131, -0.09508102321172063, -0.10691176736845245, -0.1569025641943835, -0.18159115974661671, -0.11240258322486395, -0.10757324811589869, -0.10880667414469054, -0.17464232689972164, -0.13086032575067086, -0.19407797576505928, -0.14067690253634996, -0.12179087173146537, -0.19575318712976914, -0.1325905701687819, -0.17011750366868852, -0.10627622694908818, -0.14916776549778407, -0.14893645722466178, -0.18796124012221263, -0.12848699903940852, -0.22657080654856526, -0.09847894372253478, -0.1862994068408314, -0.2383544792102862, -0.2019652153683614, -0.16187152804075916, -0.24944465918631492, -0.21795905353147774], "val_target_mve_loss": [0.7271470797212818, 0.2576139336333999, 0.07852521717925615, 0.01918970189010135, -0.006686314303852334, -0.026189812358846015, -0.041710772021095965, -0.05241084012756996, -0.06853783713017084, -0.07164738544180424, -0.07177166801087465, -0.1384166959154455, -0.11386341280952285, -0.11280780049818981, -0.10846277486674394, -0.14900100749881962, -0.10906768882576423, -0.11405321238916132, -0.09714086924361277, -0.13976360140722008, -0.08863411876785604, -0.10590672266634205, -0.15385573536534852, -0.1777240209564378, -0.11279926133118098, -0.10945218886378445, -0.10835677734281443, -0.1578733104316494, -0.12147212014356747, -0.19025641481710387, -0.13935295504199552, -0.11477908911772922, -0.1913861366389673, -0.13051922161933743, -0.15943048630334153, -0.09952073690446117, -0.13465929363818863, -0.13409425497432298, -0.17846692117709148, -0.10717351248935808, -0.22010068350200412, -0.07445514211549034, -0.16137394554252865, -0.23163867854996573, -0.1756224814283697, -0.10103181891049011, -0.2410972457897814, -0.20495026407739783], "da_weight": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "beta": [1.0, 0.992, 0.984, 0.976, 0.968, 0.96, 0.952, 0.944, 0.936, 0.928, 0.92, 0.912, 0.904, 0.896, 0.888, 0.88, 0.872, 0.864, 0.8560000000000001, 0.8480000000000001, 0.8400000000000001, 0.8320000000000001, 0.8240000000000001, 0.8160000000000001, 0.808, 0.8, 0.792, 0.784, 0.776, 0.768, 0.76, 0.752, 0.744, 0.736, 0.728, 0.7200000000000001, 0.7120000000000001, 0.7040000000000001, 0.6960000000000001, 0.6880000000000001, 0.68, 0.6720000000000002, 0.6640000000000001, 0.6560000000000001, 0.6480000000000001, 0.6400000000000001, 0.6320000000000001, 0.6240000000000001], "epoch_no": 48}
Binary file not shown.
Loading

0 comments on commit d5413a7

Please sign in to comment.