Skip to content

Commit

Permalink
periodic isomap saves
Browse files Browse the repository at this point in the history
  • Loading branch information
ShrihanSolo committed Jul 18, 2024
1 parent 9bdcb68 commit 40e2873
Show file tree
Hide file tree
Showing 11 changed files with 352 additions and 65 deletions.
327 changes: 263 additions & 64 deletions training/notebooks/MMD_paper/multiband/ShrihanPaperMMD_mb_isomap.ipynb

Large diffs are not rendered by default.

Binary file not shown.
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"train_DA_loss": [0.08378414178209233, 0.08198375869167789, 0.08236151533048103, 0.08080680694744345, 0.08211931031388, 0.08275454559821424, 0.08287133068745271, 0.08142151561662746, 0.08085973944406023, 0.08445909800317812, 0.07996224794346138, 0.08086163974392621, 0.08069792221482482, 0.08055360312701307, 0.08050990269487039, 0.08096892103872907, 0.08183459029916054, 0.08243224966310823, 0.08199209916035916, 0.0819357118264549], "train_regression_loss": [0.009391380612094848, 0.009178896025980958, 0.009190522177933189, 0.00886911467284686, 0.008736877544898833, 0.00887472045112408, 0.008756233715721264, 0.008615461066177955, 0.008514590073382003, 0.008546264787243723, 0.008310145841080566, 0.008331553514202346, 0.008277214531934866, 0.008203912652787197, 0.008180848222568126, 0.0082110068028694, 0.008150385180871314, 0.008147232494337426, 0.008167800685080957, 0.00806364893655325], "train_r2_score": [0.9809248639696921, 0.9814300608737279, 0.9813707330784385, 0.9819899415684537, 0.982278221728043, 0.9819762203916629, 0.9822324644699466, 0.9826174888862467, 0.9826122665917509, 0.9826032708504356, 0.9831028088647872, 0.9831567424124741, 0.9831296660656652, 0.9833171290982256, 0.9833672756245709, 0.9833173826972216, 0.9833793038018882, 0.9834479960127092, 0.9834181009616261, 0.9837264749869145], "val_source_regression_loss": [0.009441302943927277, 0.008904439521728048, 0.010961075968874297, 0.009462258072605558, 0.009030804524471046, 0.008518964227194049, 0.009033691132097108, 0.008367692336890917, 0.009589782336157314, 0.008646954416006708, 0.008666613438777672, 0.008876111082566581, 0.008612159918400513, 0.008388509637243144, 0.008123774820001452, 0.008458099452553281, 0.0086224014503039, 0.008317165847891455, 0.007972991211150008, 0.008082788582938682], "val_target_regression_loss": [0.03445017693718527, 0.033362161344403674, 0.032156644774612726, 0.03259899389520762, 0.03166891436335767, 0.03140123908046135, 0.030953935130979795, 0.030326274460905297, 0.02962588339118631, 0.02884125809191139, 0.029115843823903306, 0.02935293631236644, 0.027676375989155594, 0.027145135518946466, 0.026570827895953397, 0.02740418129440421, 0.027217535794398208, 0.02657648388913293, 0.025029130063500184, 0.02599543741174564], "val_source_r2_score": [0.9804536489131856, 0.9817020160517522, 0.9771000587144241, 0.980064917478021, 0.9811985359545702, 0.9821682616180019, 0.9812798065133974, 0.9828130111966193, 0.9799517683027008, 0.9817499849917694, 0.9819567899101118, 0.9817006716725657, 0.9823408579426796, 0.9823631453597982, 0.9830729691675761, 0.9824007476185875, 0.9821727616122419, 0.9827857262887272, 0.9833852269065878, 0.9833497832149051], "val_target_r2_score": [0.9313153153737631, 0.933765386297871, 0.9346639260284142, 0.9342031663390251, 0.9365122089842283, 0.9370807821951758, 0.9372733093506636, 0.9385424821569882, 0.9392392276137543, 0.9424741068985133, 0.9416604112403434, 0.9416402568048481, 0.9441250226534156, 0.9455187100757452, 0.9466920099929211, 0.9452665097291194, 0.9456600247799319, 0.9460795065029521, 0.9501646547760318, 0.948105812516281], "epoch_no": 50}
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"train_DA_loss": [0.25634037361393464, 0.1529867263405262, 0.12254512749980093, 0.11340398337671591, 0.10617097410998244, 0.10283422358061649, 0.0982255144197991, 0.09740277223452462, 0.09529518759664714, 0.09495379370203075, 0.09419018542308912, 0.09062781555567898, 0.08903767743660393, 0.09005740223966634, 0.08607240709079467, 0.08797339139884199], "train_regression_loss": [0.336621464260767, 0.05377123775164949, 0.037821270600369904, 0.02981724985066454, 0.02558601138708088, 0.023034274393787958, 0.02089970356059453, 0.01898107022537392, 0.01791435308665909, 0.01660537798088759, 0.015859879824457183, 0.015232678193776055, 0.013787721795365351, 0.013194258755559895, 0.013083366719517223, 0.012690169518810086], "train_r2_score": [0.3156164437602291, 0.8913548572515785, 0.9235116061159699, 0.9397671627422692, 0.9481836251726783, 0.9534256806677776, 0.9577760069661724, 0.9614602799138312, 0.9636021248974841, 0.9663575225957239, 0.9678829367019922, 0.9692269646050838, 0.9720265641663174, 0.9732796611555271, 0.9734600884686375, 0.9742547156417153], "val_source_regression_loss": [0.0663455261190416, 0.04191677266387803, 0.031834045573357186, 0.025973013430169434, 0.023940730828103746, 0.02078595564101532, 0.01946775425984791, 0.017391526712709743, 0.017003588995356467, 0.016124629842675035, 0.014520380204057048, 0.01376971259570805, 0.013096123999044014, 0.012525101011369829, 0.012574149171115866, 0.012029730492407919], "val_target_regression_loss": [0.3956533555582071, 0.19946468806570503, 0.11999595925733922, 0.09607967842299088, 0.09200644936815948, 0.08044255209291816, 0.07603770340229296, 0.07313325435255363, 0.07401863837934983, 0.06409259590136397, 0.059685280226218475, 0.05767358818156704, 0.05554958660701278, 0.05511639780916606, 0.05489507218478781, 0.050551147170507224], "val_source_r2_score": [0.8642934020842643, 0.911934746781441, 0.934198850730005, 0.9465030037015932, 0.9507924509036877, 0.9574360503739061, 0.9600621735796862, 0.9643973938572106, 0.9646598834222077, 0.9664420915914836, 0.9704427980690716, 0.9717252969536749, 0.9733185482209324, 0.9738902047320419, 0.9739782398580092, 0.9751454540686537], "val_target_r2_score": [0.1852557538841319, 0.5927222938650744, 0.760441590839082, 0.8071971737169826, 0.813441127592621, 0.8369550461874634, 0.8491618199597255, 0.8519016232139623, 0.8489128245295819, 0.8717248626131479, 0.8796438672977519, 0.8845297950291477, 0.8871027691167279, 0.8884507857081992, 0.8903678163698389, 0.8985657420296078], "epoch_no": 16}
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"train_DA_loss": [0.1822167321140631, 0.11226800804995204, 0.09932632561452973, 0.09665221010608908, 0.09196142268654851, 0.09012669546894007, 0.08700468379261833, 0.08734025711183466, 0.08576066949331564, 0.08647498808515137, 0.08602151169460369, 0.08214884036828439, 0.08003148613276277, 0.08393759378512991, 0.08360784630575498, 0.0797428227075711, 0.08110615582813184, 0.08059697185338252, 0.08000697373621969, 0.08039855190860942, 0.07775339476723339, 0.07813146184292201, 0.07902205301111588, 0.07921343478604034, 0.07659977944857242, 0.07823911087904159, 0.07922007030444782, 0.09076427604655787, 0.13119789195235076, 0.14799516253050765], "train_regression_loss": [0.15798097231018424, 0.03061670491427619, 0.024216612857567245, 0.020289248663428114, 0.017746864535404843, 0.016200108495345117, 0.014442755257152362, 0.013129382165680065, 0.012111047545866786, 0.011353759948075689, 0.011007407698456805, 0.0107376593594155, 0.010041277339746682, 0.009579623598389164, 0.009508036070850625, 0.009392577380970312, 0.00915768886133143, 0.00908215786841182, 0.009161424446830584, 0.009118012754268726, 0.009048333746113796, 0.00935784660912267, 0.009804502756324626, 0.010485295998656158, 0.011591486372751689, 0.013030515128569559, 0.015271760994939635, 0.01977164234405884, 0.02739578155354619, 0.03182332557362166], "train_r2_score": [0.6814056945700145, 0.9381384587147804, 0.9509657173194936, 0.9588725833873933, 0.9638580841572095, 0.9672244107982432, 0.970815472763787, 0.9732968343951174, 0.975326385573006, 0.9769099395955135, 0.9776771784228029, 0.9783046956138662, 0.9796344422765053, 0.9806549111431212, 0.9807469982770779, 0.9809892552332318, 0.9813382940400243, 0.9814902321854917, 0.9813814795751896, 0.9815614424877345, 0.9816449888646142, 0.9810400298611178, 0.9800735172405315, 0.9787402003615421, 0.9765578355395444, 0.9735644319887561, 0.9691167172370773, 0.9600142794607922, 0.9450418575945152, 0.9362136536770308], "val_source_regression_loss": [0.035686239232398144, 0.02518023810924808, 0.021067639136580146, 0.017990550898300234, 0.016817859979049794, 0.015305314551160973, 0.012830608094312773, 0.011692757519543361, 0.011188466491913246, 0.011173695518641145, 0.010096299497612343, 0.009860927968695286, 0.010086293544347404, 0.009231846310957602, 0.010745680929535324, 0.009409839419387044, 0.010065671400564491, 0.010786521174369535, 0.009057218463984645, 0.008696894062042331, 0.011340422638852125, 0.009510888762558532, 0.010026964741598839, 0.011170657095054437, 0.013786663158921299, 0.013475929033960317, 0.01574741205153097, 0.020364916112250202, 0.029807959867131178, 0.031858164650999055], "val_target_regression_loss": [0.1358682672214356, 0.0910508253131133, 0.0799491315676718, 0.06665341260659087, 0.06299033355276296, 0.049992994897684474, 0.0510943352488006, 0.04696251351125301, 0.04277730217072994, 0.04064102950179653, 0.040577148223758505, 0.036529210686778564, 0.036534304534838455, 0.03178030846248956, 0.033461145694800624, 0.029489581686723384, 0.03045877426044102, 0.030858389161242422, 0.027390287133159153, 0.027629629910514233, 0.02786909778762585, 0.026394053771617305, 0.025690048630497637, 0.028870397158394193, 0.028121928370017913, 0.036313307038538016, 0.0382306297772392, 0.042842514326523064, 0.05332454519975147, 0.05248734529373372], "val_source_r2_score": [0.9271528938595088, 0.9471158405688828, 0.9561120589275367, 0.9626300419397941, 0.9653985776736688, 0.9686291980478279, 0.9736472560416612, 0.9759744240452036, 0.9767691867666636, 0.9759516153380151, 0.9793548690964416, 0.9797135143305178, 0.9787106447030857, 0.9809553534420365, 0.9778185685765423, 0.9805170814244919, 0.9790972637841723, 0.9779365399540123, 0.981070052648232, 0.9820770670287963, 0.9767371138705749, 0.9803070065344857, 0.9794586940132954, 0.9769376382503888, 0.9712074128001379, 0.9720144851975109, 0.9672822314343404, 0.9581839902991478, 0.9384927532724194, 0.934083059414067], "val_target_r2_score": [0.7210121503955959, 0.814706076379305, 0.8404279179082621, 0.8664522620004899, 0.8721636569523055, 0.8989183071638734, 0.8989902913167167, 0.9059525834858984, 0.9137708366837115, 0.9191371653269825, 0.9181791163001786, 0.927445464738762, 0.926490483296046, 0.9362431614754219, 0.9321395563952083, 0.9409945992209153, 0.9381162421027971, 0.9387283090746618, 0.9448446017971687, 0.9453357164245029, 0.9442351818533671, 0.9472376190587193, 0.9485558832894643, 0.9408648657121568, 0.9431900031350122, 0.9263560661055724, 0.9232398790905888, 0.913806119699367, 0.8904669031715455, 0.895645167595449]}
87 changes: 86 additions & 1 deletion training/notebooks/MMD_paper/multiband/train_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,4 +461,89 @@ def generate_isomaps(source_data, target_data, model, n_neighbors = 5, n_compone
source_iso = isomap.transform(sdata)
target_iso = isomap.transform(tdata)

return source_iso, target_iso, trained_source_iso, trained_target_iso
return source_iso, target_iso, trained_source_iso, trained_target_iso

def show_isomaps(source_iso,
target_iso,
trained_source_iso,
trained_target_iso,
source_labels,
target_labels,
mod_name,
epoch_no,
pretrain_lim = 500,
posttrain_lim = 50,
save = False):

fig0, axes = plt.subplots(1, 2, figsize=(8, 4))

(ax1, ax2) = axes
ax1.scatter(source_iso[:, 0], source_iso[:, 1], s=3, marker='o')
ax1.scatter(target_iso[:, 0], target_iso[:, 1], s=3, marker='^')
lval1 = pretrain_lim
ax1.set_xlim(-lval1, lval1)
ax1.set_ylim(-lval1, lval1)
ax1.set_title('Source and Target')

ax2.scatter(trained_source_iso[:, 0], trained_source_iso[:, 1], s=3, marker='o')
ax2.scatter(trained_target_iso[:, 0], trained_target_iso[:, 1], s=3, marker='^')
lval2 = posttrain_lim
ax2.set_xlim(-lval2, lval2)
ax2.set_ylim(-lval2, lval2)
ax2.set_title('Trained Source and Target')

ax1.set_xlabel('Component 1')
ax1.set_ylabel('Component 2')
ax2.set_xlabel('Component 1')
ax2.set_ylabel('Component 2')

if save:
plt.savefig(mod_name + "_" + str(epoch_no) + "_compare.png", bbox_inches = 'tight', dpi = 400)

plt.show()

fig1, ax = plt.subplots(2, 2, figsize=(14, 10))

ax1 = ax[0][1]
scatter1 = ax1.scatter(trained_source_iso[:, 0], trained_source_iso[:, 1], s=3, marker='o', c = source_labels)
lval1 = posttrain_lim
ax1.set_xlim(-lval1, lval1)
ax1.set_ylim(-lval1, lval1)
ax1.set_title('Trained Source')

ax2 = ax[0][0]
ax2.scatter(source_iso[:, 0], source_iso[:, 1], s=3, c = source_labels)
lval2 = pretrain_lim
ax2.set_xlim(-lval2, lval2)
ax2.set_ylim(-lval2, lval2)
ax2.set_title('Source')

ax1 = ax[1][1]
ax1.scatter(trained_target_iso[:, 0], trained_target_iso[:, 1], s=3, marker='o', c = target_labels)
lval1 = posttrain_lim
ax1.set_xlim(-lval1, lval1)
ax1.set_ylim(-lval1, lval1)
ax1.set_title('Trained Target')

ax2 = ax[1][0]
ax2.scatter(target_iso[:, 0], target_iso[:, 1], s=3, c = target_labels)
lval2 = pretrain_lim
ax2.set_xlim(-lval2, lval2)
ax2.set_ylim(-lval2, lval2)
ax2.set_title('Target')

for i in ax.ravel():
i.set_xlabel('Component 1')
i.set_ylabel('Component 2')

cbar = fig1.colorbar(scatter1, ax=ax.ravel().tolist(), orientation='vertical')
cbar.set_label('$\\theta_E$')

plt.suptitle("Isomap of Regression Inputs: Before and After", x = 0.44, y = 0.94, fontsize = 20)

if save:
plt.savefig(mod_name + "_" + str(epoch_no) + "_thetaE.png", bbox_inches = 'tight', dpi = 400)

plt.show()

return fig0, axes, fig1, ax

0 comments on commit 40e2873

Please sign in to comment.