-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodel_weights.py
More file actions
77 lines (57 loc) · 15.3 KB
/
model_weights.py
File metadata and controls
77 lines (57 loc) · 15.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
# -*- coding: utf-8 -*-
"""
Created on Sat Feb 24 09:49:07 2024
@author: Prajit
"""
import os
import re
import torch
import numpy as np
import matplotlib.pyplot as plt
from stable_baselines3 import PPO # or the appropriate algorithm you're using
model_dir = 'tmp2/models'

# Collect checkpoint files named like "model_at_step_<N>.zip" together with the
# step number N (the first run of digits in the name, as before). A file that
# passes the name filter but contains no digits used to crash the script with
# IndexError; it is now skipped with a message instead.
model_files = []
timesteps = []  # timesteps[k] is the step parsed from model_files[k]
for file_name in os.listdir(model_dir):
    if not (file_name.startswith("model_at_step_") and file_name.endswith(".zip")):
        continue
    step_match = re.search(r'\d+', file_name)
    if step_match is None:
        print(f"Skipping checkpoint without a step number: {file_name}")
        continue
    model_files.append(file_name)
    timesteps.append(int(step_match.group()))

# Order checkpoints by their numeric step (a plain string sort would put
# "model_at_step_1000" before "model_at_step_200").
sorted_files = [name for _, name in sorted(zip(timesteps, model_files), key=lambda pair: pair[0])]
def load_model_weights(model_path):
    """Load a saved PPO model and return all of its tensor values as one flat array.

    ``model.get_parameters()`` returns a mapping from object names to state
    dicts. Every ``torch.Tensor`` found directly inside those state dicts is
    moved to the CPU, flattened, and concatenated, in dict iteration order,
    into a single 1-D NumPy array. Entries that are not tensors are skipped;
    only their count is reported rather than printing each object in full.
    NOTE(review): in SB3 those non-tensor entries appear to come from the
    optimizer state dict (``state`` / ``param_groups``) — confirm.

    Args:
        model_path: checkpoint path accepted by ``PPO.load``.

    Returns:
        1-D ``numpy.ndarray`` with every tensor value. Two checkpoints of the
        same architecture produce arrays whose elements line up one-to-one,
        because the state dicts list their tensors in the same order.

    Raises:
        ValueError: if the loaded model exposes no tensor values at all.
    """
    # Only the raw numbers are needed, so load on the CPU instead of letting
    # "auto" place every checkpoint on a GPU.
    model = PPO.load(model_path, device="cpu")

    flat_chunks = []
    for name, state in model.get_parameters().items():
        if not isinstance(state, dict):
            continue
        skipped = 0
        for value in state.values():
            if isinstance(value, torch.Tensor):
                flat_chunks.append(value.detach().cpu().numpy().ravel())
            else:
                skipped += 1
        if skipped:
            print(f"{model_path}: skipped {skipped} non-tensor entries in '{name}'")

    if not flat_chunks:
        raise ValueError(f"No tensor parameters found in model at {model_path}")
    return np.concatenate(flat_chunks)
def calculate_l2_norm(weights1, weights2):
    """Return the Euclidean (L2) norm of ``weights1 - weights2``.

    Each argument may be a single array of any shape (such as the flat vector
    returned by ``load_model_weights``) or a sequence of arrays (for example one
    per layer). Sequences are flattened and concatenated in order, so the
    result is the L2 norm over all values together.

    The previous implementation zipped the inputs element-by-element. For flat
    vectors that summed absolute differences (an L1 distance), not the L2 norm
    the name and plot labels promise, and it truncated silently on length
    mismatch.

    Raises:
        ValueError: if the two inputs do not hold the same number of values.
    """
    def _flatten(weights):
        # One contiguous float64 vector. float64 avoids float32 round-off when
        # summing squares over millions of parameters.
        if isinstance(weights, np.ndarray):
            return weights.astype(np.float64, copy=False).ravel()
        parts = [np.ravel(np.asarray(w, dtype=np.float64)) for w in weights]
        return np.concatenate(parts) if parts else np.zeros(0)

    flat1 = _flatten(weights1)
    flat2 = _flatten(weights2)
    if flat1.shape != flat2.shape:
        raise ValueError(
            f"Weight vectors differ in size: {flat1.size} vs {flat2.size}"
        )
    return float(np.linalg.norm(flat1 - flat2))
# Distance between every pair of consecutive checkpoints, in step order.
# Each checkpoint is loaded exactly once: the previous iteration's weights are
# kept and reused, instead of loading every interior file twice.
l2_norms = []
if len(sorted_files) < 2:
    print(f"Need at least two checkpoints in {model_dir!r} to compare; found {len(sorted_files)}.")
else:
    previous_weights = load_model_weights(os.path.join(model_dir, sorted_files[0]))
    for file_name in sorted_files[1:]:
        current_weights = load_model_weights(os.path.join(model_dir, file_name))
        # Same argument order as before: earlier checkpoint first.
        l2_norms.append(calculate_l2_norm(previous_weights, current_weights))
        previous_weights = current_weights
print(l2_norms)

plt.plot(l2_norms)
plt.xlabel('Model Interval')
plt.ylabel('L2 Norm of Weight Changes')
plt.title('L2 Norm of Weight Changes Between Consecutive Models')
plt.show()
# Hard-coded snapshot of weight-change distances between consecutive
# checkpoints, presumably pasted from the print(l2_norms) output of an earlier
# run so the plot below can be redrawn without reloading every model.
# NOTE(review): these numbers only reflect the checkpoints and the distance
# function in use when they were recorded; regenerate them if either changes.
List = [3902.857998023756, 2107.849871025452, 2068.567817578447, 1953.3529669764544, 2044.3008034624022, 2063.8453143755924, 1952.6415828905328, 1959.8634486591493, 2053.042460166489, 1957.2514215017136, 1984.89977414191, 2012.0582876504734, 2062.421051943985, 2004.362064441091, 2036.2028908654972, 2020.5769803137214, 1994.8379911531754, 2023.9523171557137, 1977.1185911262357, 2010.8160052585633, 2061.781220740356, 1971.5150371748216, 1994.174034252053, 2044.9914447375465, 2046.8965593873631, 2019.3622733246193, 1980.7110363183524, 1952.2086466787105, 2038.926940357232, 2012.4979324438448, 1994.3351355556479, 2019.9318701614357, 2025.3468318080688, 2012.1669496545778, 2043.8278140526552, 1982.0133116595744, 2041.7632195297567, 2024.947189868723, 2012.631498670553, 2014.334473401814, 2015.8521777084636, 1961.6805969046466, 2034.3210604059686, 2012.3232666883953, 1989.6877583334326, 2059.782218716631, 2015.583695246059, 2007.2520048815734, 2066.1194225405043, 2011.2036088043626, 1993.4363392109042, 2046.3623315712314, 2000.706566919569, 2043.3109350425411, 2008.7714355229289, 2013.2260748518272, 2039.5298525419844, 2003.284889338086, 2052.668761208855, 2066.3176196092018, 2024.7416573965043, 1991.2803969484962, 2010.8490320992958, 1964.616650312746, 2071.1400788047877, 2022.3992885379585, 2067.391229533081, 2015.5414868808214, 2026.4119199989834, 1997.0320538003907, 2021.065394199328, 2027.8913225252657, 2094.4608431517963, 1999.6963083089943, 2039.250700004026, 2025.9491361855733, 2013.9817633531936, 2040.1222931388652, 2069.9976278469567, 2001.7349079955322, 2045.1932439628458, 2013.3212708183141, 2024.801907493686, 2105.276348940772, 2079.133546163163, 2121.5063491764477, 2072.4810858144524, 2037.124994534066, 2078.0789342567614, 2066.7542368934164, 2020.725277317208, 2057.491729371227, 2083.1790625166213, 2097.747650442311, 2052.999970979265, 2048.597869048999, 2068.0357273615855, 2028.6138821119107, 2046.0216429463792, 2035.7263687959335, 2073.7040687853255, 
2068.2463332859047, 2049.0311734945626, 2117.3172367266843, 2021.7583093807916, 2061.610680513105, 2111.2356423966876, 2012.4715427597082, 2051.285534248037, 2048.4907604125083, 2036.5740793907885, 2049.564770733128, 2003.2779591772587, 2054.4343139179427, 2046.8613983568594, 2085.175645611097, 1962.7692652648482, 2057.098537601824, 2043.4223629958171, 2044.0210819547312, 1984.9692678500094, 2004.9204516583557, 2010.6545461920573, 2034.6904196863063, 2016.8419801733517, 1986.2060986148747, 2002.4645837223075, 2035.6234993986593, 1993.2527826586413, 2019.146159578907, 2024.4576684926856, 1995.995902002645, 1984.243941267795, 2004.7120027519595, 2070.800146322721, 1993.598413011494, 2063.168111707413, 2023.3029919081766, 2053.8806386293604, 2016.427650532532, 2016.3410885820172, 1977.26365412816, 2061.173398376893, 2005.9549246132647, 1997.242290696165, 2056.2498011541797, 2006.8265148171631, 2024.6812090366059, 2071.7448084144985, 2065.9463326075625, 2067.5951106517323, 2004.7104017020383, 1995.400035960504, 2031.3327214904057, 2085.7347700462815, 2045.6926479772574, 1938.229150212259, 2035.1819415573336, 2025.8415308021547, 2016.6180274533967, 2043.6196431419226, 2034.31449232899, 2000.4286759222534, 1960.8191111406666, 2016.9957995182294, 2029.3265378106003, 2013.0408890552046, 1946.1716957727606, 2070.977854308103, 1978.3344923436412, 2007.5440913896732, 2028.9481904469214, 2083.1201101178763, 2008.5843227311207, 2026.5289125145994, 1988.3386031704003, 2061.596671957163, 1946.2699066998669, 1975.6532679742195, 2038.1315480282556, 2036.1682085008256, 1976.4778360939135, 2000.8559532350155, 2036.332563019232, 1992.0711601732044, 2001.5402788323252, 2065.450713192869, 2025.7609026249884, 1994.8190329148276, 1986.3167604198898, 2001.679319292889, 2029.906494491304, 2038.9126586840187, 1987.7795205202674, 2074.269025201013, 2014.0370752950523, 2011.196079048047, 2014.8859171209863, 2023.1765804291706, 2022.8770663917815, 2093.666278412625, 1997.8190754042687, 
2032.2766529850098, 2053.2347688887735, 2036.5951869530209, 1990.8220444352103, 2026.1922564096965, 1995.9513396573593, 2010.4770031506334, 2069.905717187311, 1975.8598508252162, 2045.8632864898248, 2067.304286564462, 2048.317098204278, 2117.452652919546, 2058.8327507104154, 1987.6390352054384, 2051.4952388510837, 1980.1717643041789, 2044.3971552942498, 2058.220484851221, 2054.330088302581, 2056.201497176678, 2051.859255421968, 2043.074204776367, 2047.7649405699492, 2053.108201579697, 2070.5505205386225, 1982.568594645562, 2035.4270258087035, 2071.5696615297707, 2022.4451876791247, 2081.7085888500496, 2015.9035827923647, 2029.6453211406676, 1992.7733415168072, 2030.3998346972844, 2042.3099452975166, 2019.4690034856926, 2111.984428706645, 2018.4854288173865, 2085.29831637893, 2098.376251372923, 2092.094232606542, 2057.3643632392277, 2005.7852854853368, 2009.3503573076773, 2035.610813263208, 2067.388213634561, 2064.018884696972, 2069.826138657009, 2018.751028567608, 2035.1666442046076, 2027.4814463341836, 2024.4038750909392, 2079.537656881662, 2062.36436536323, 2072.412417956188, 2104.662765368056, 2053.513326308113, 1994.3089865925022, 2028.159217344855, 2074.8986145942067, 2054.7093582704265, 2068.8859064049275, 2075.4129384821135, 2084.1483115756105, 2041.2453063715138, 2036.7896614671151, 2054.210896785963, 2085.5359995351055, 2090.735498501306, 2044.7803129743709, 2081.443186895378, 2089.2348297822014, 2038.5913532173417, 2051.995244650573, 2028.366185200087, 2050.5202396698405, 2085.647316102197, 2102.0913995488327, 2092.3770623057803, 2105.230224165803, 2040.3740395941468, 2114.1387118624402, 2039.6398998222614, 2067.1988634734207, 2058.2259286865424, 2000.775017919914, 2091.5388996839065, 2117.7477685302347, 2020.0516076310628, 2111.0229638409664, 2054.5975469932055, 2079.6202575573056, 2042.9192020887535, 2105.4410068384163, 2076.5561476681332, 2112.7484122838423, 2081.0990712551247, 2062.755362448015, 2082.315785629164, 2039.1784849532064, 
2069.5465962023236, 2073.2430658123067, 2105.625137697757, 2066.7743650022257, 2116.491155019, 2087.8318237024687, 2074.98454081598, 2141.4148425712456, 2066.655616850857, 2055.349006284033, 2074.9496011449437, 2053.7325301889205, 2016.9583412690451, 2083.047311599291, 2101.705011876036, 2077.837704737356, 2057.072434613694, 2013.5227503931274, 2077.6657624302093, 2100.601412436412, 2084.085318459899, 2066.1476210082224, 2028.1741996720887, 2110.198771886159, 2108.914680623481, 2127.9113562417615, 2079.6392281194444, 2047.278097472292, 2120.4373038459626, 2038.6400980390545, 2011.4884426981434, 2047.179807371626, 2180.13102209748, 2069.910319779301, 2076.7134022630007, 2038.22770061706, 2077.089221962358, 2133.5797837814766, 2051.5018938398653, 2074.9464643958545, 2111.5022858316365, 2081.1442152584273, 2012.4750696087367, 2120.6031492205316, 2065.0567822289668, 2061.234640393124, 2095.9792608985504, 2067.6960976983723, 2044.042453532554, 2090.9923982787714, 2080.3473761708246, 2093.194048137213, 2124.790910215771, 2071.328685604993, 2064.2590194194763, 2056.285448729377, 2033.6262866017028, 2033.696433705859, 2022.9636759977582, 2056.6975329766406, 2055.7769994646724, 2076.9504456239197, 2082.333252029852, 1992.7085719909194, 2061.532085715433, 2006.1530685434884, 2046.168151364245, 2022.3007358141308, 2142.7096135876222, 2073.0995404169726, 2092.819529754758, 2052.224703965505, 1973.8969743554135, 2055.4215990496496, 2100.246197447794, 2037.103206753098, 2104.324601203649, 2070.0010736873655, 1991.5522008194566, 2040.7537330586474, 2012.3156880179904, 2055.001802477107, 2008.6173733807414, 1995.6913938223488, 1991.0941169245593, 2089.531425765126, 1991.003337127575, 2023.570076646541, 2011.1803715170981, 1998.1238492028112, 2099.620684891709, 2075.0299474176263, 2044.1509227117565, 2015.8700943001122, 2010.7277408140392, 1995.3073677142875, 2065.048640886457, 2053.7275996809108, 2022.2328254699867, 2065.9831814501745, 2089.201237316908, 2041.0125812597275, 
2028.8907041805874, 2031.2154639335727, 2049.045406076021, 1995.9782077385635, 2026.62594867943, 2018.9001324912042, 2047.1203824212598, 1967.931892821518, 2076.7457926058555, 2036.0431643629934, 2035.345723730483, 2057.256146312149, 1972.2367297757874, 2071.2521914831505, 2010.3861465229302, 2015.6245457599666, 2042.868103059814, 2036.1133461206632, 2050.6694826142366, 2050.7953416137093, 2041.5170500289632, 2065.9578275620775, 2030.5790968643728, 2045.2274982746987, 2038.2505021269417, 2024.3488827739634, 2055.037126012218, 2026.3318304822626, 2024.2273591117473, 2023.2421608253194, 2068.964277424009, 2104.6326334092596, 2022.65542729973, 2061.4528330135877, 2030.0354166764985, 2030.081055721122, 2038.2644042083157, 2083.33694987433, 2066.0324702903436, 2030.7024171413232, 2034.6157294737488, 1981.94719767065, 2052.1004256808446, 2015.9760828764893, 2082.8310044837117, 2048.8405116846616, 2006.3499361910026, 2038.5759340687987, 2083.849118893435, 2058.4027563366535, 2036.4751012774973, 2058.962960024091, 2022.164601926498, 2065.9215068092462, 2012.2518052938044, 2083.027636312866, 2055.970327415047, 2035.801802682069, 1997.7767098231411, 2095.9785688492116, 2064.4117722336355, 2072.24271711317, 2090.535080224142, 2056.870494371855, 2022.3768932284281, 2054.5605774310975, 2007.548676010154, 2052.4340913830547, 2042.028126779611, 2095.454795326379, 2053.419753901621, 2044.5165355055747, 2032.304064573113, 2063.069688198691, 2075.1041943394616, 2014.8915300521244, 2032.7133288795649, 2026.4119051225482, 2035.1443103433016, 2006.4954534143526, 1988.7620771744648, 2025.5678328146937, 2007.0758053461498, 2040.8481081957195, 2020.14805908804, 2083.3541791902408, 2072.3525476353343, 2061.498098279475, 2029.9842331041257, 1954.0470462924118, 2023.34318149161, 2066.2921896113385, 1985.377710650962, 1965.8071751307011, 2010.8032930493277, 1984.4409500233314, 2010.3824754204472, 2046.5939246779303, 2116.1117263602414, 2025.3208108315207, 2022.330751448114, 2017.08679380151, 
2008.2868698469392, 2063.3266618181906, 2028.241986395396, 2069.8695199301596, 2004.604753137046, 2045.5411295311044, 2065.5674737250993, 2003.0230421826564, 2041.1186799239667, 2044.115509950489, 1995.3955238099625, 2008.051018267388, 2083.0423333666904, 2060.581578880377, 2053.876334057541, 2040.7697516939174, 1986.2212405149019, 2041.895966587124, 2019.3456481684127, 2087.760575740902, 2010.853187481583, 2005.028157731672, 2058.3185653132996, 1998.232853891192, 2063.5995106020973, 2082.8278994033512, 2084.544454684703, 2135.4735040113746, 2059.4066223298364, 2044.1161659009053, 2079.398828936907, 2031.4243804474602, 2014.6853461777994, 2077.3950464452037, 2100.0615634183428, 2010.5957718446239, 2051.7388761928573, 1994.1583813415132, 1940.4325061828115, 1967.3220345285974, 2012.9469731013226, 2036.4079944563005, 2038.9803299160772, 2014.4385274422673, 2032.3663950310838, 2007.3559863610576, 2029.5094414652492, 2032.1862211810494, 2033.182636093722, 2077.104914657349, 2096.2361058763317, 2065.5090218233927, 2005.4726484577388, 2060.415949559681, 2078.649238285935, 2010.1187098181795, 2049.678257642243, 2025.3738685540197, 2029.3584310871204, 1978.1840949930286, 2008.2069579054235, 2021.753786902448, 2000.0216239798629, 2048.8396433561247, 2008.509105215321, 2025.1587218193117, 2017.964890538134, 1946.4996672688526, 2049.700010009189, 2061.6454873466273, 2050.3690591633494, 1968.8028274273208, 2020.2028942422637, 2059.7911759425824, 2034.1532864432902, 1963.7160185144446, 2076.0312256982565, 2068.5833033921767, 2007.614367206903, 2089.148261268234, 2035.905121977763, 2096.383041174677, 2024.9395316032108, 2077.250720994704, 1888.2246468047215, 2029.437591269257, 2011.6317829368502, 2076.068328536938, 2003.3649370430846, 2030.3492927024884, 1985.5437665127617, 1965.3472638202745, 2004.3149870947982, 2052.02906055593, 2043.3527122461478, 2049.8782307965557, 2006.0066789036805, 2020.3538045527775, 2012.7131740064094, 1944.5105143607957, 1996.2809758703377, 
2062.81439575239, 2020.9846404638165, 1998.3379368866756, 1941.8092956390944, 2023.0099019412148, 2019.9029711872672, 2121.9499909568526, 2017.007326898352, 1924.5279268723893, 2006.9710527535526, 2045.2826906001258, 1976.042147426073, 2083.471727033826, 2052.435270712424, 2027.5165852358537, 2100.210563025822, 2005.7777175835192, 2039.5066807093247, 1993.3882166392225, 2064.32599831227, 1997.889809203243, 2042.1174220403434, 1990.260028296462, 1977.177281088592, 2005.8492798360471, 2074.7095999084418, 2016.8712521276418, 1922.6057212020694, 2076.780955136157, 2026.023874954024, 2045.2837078992945, 1960.3478138410287, 2041.0349358879698, 1962.041792887498, 2012.886744782314, 1976.1368171306176, 2041.109303957208, 1991.415546625444, 2016.580513406719, 2031.5135270956284, 1984.8235001890619, 2012.6343463483809, 2123.647823683608, 2058.7339669482626, 1940.6027702606727, 2016.7872610969146, 2010.3146151850704, 2069.179585297161, 2065.3021221841586, 1965.0747064850834, 2007.2949308507996, 1966.623211109267, 2047.558604210089, 2002.3035471740097, 1981.5913056127877, 1958.5891677001657, 2103.892345439214, 2041.438787971817, 2014.1925444704189, 1969.8150180177386, 2024.4711212709176, 2006.2721048636995, 1990.5389455623072, 1988.8454315918966, 2062.6947735397234, 1908.292162258633, 2034.4565577067688, 1985.1854511448032, 2014.0626958117107, 2021.9965481320703, 1999.9873719534608, 2081.787391231464, 1920.9994196447512, 2045.9944983411183, 1993.172665421727, 2080.3492941172603]
# Re-plot the recorded distances for a window of intervals. List[0] (~3900)
# is far above the rest (~2000) and would squash the y-axis, which is
# presumably why the first entries are dropped.
# NOTE(review): the reason for exactly 10 and 680 is not recorded here.
plot_start, plot_stop = 10, 680
plotted = List[plot_start:plot_stop]
# Pass the real interval indices as x. Plotting the slice alone would
# renumber interval 10 as 0 and mislabel the "Model Interval" axis.
plt.plot(range(plot_start, plot_start + len(plotted)), plotted)
plt.xlabel('Model Interval')
plt.ylabel('L2 Norm of Weight Changes')
plt.title('L2 Norm of Weight Changes Between Consecutive Models')
plt.show()