import numpy as np
from scipy import stats
from scipy import signal
from scipy.optimize import curve_fit
import glob
import yaml
import re
import matplotlib.pyplot as plt
import pickle
def save_pkl(obj, path):
    '''Serialize ``obj`` into the file ``path`` using pickle's highest protocol.

    :param obj: python object (list, dictionary...) to be saved
    :type obj: generic
    :param path: destination file path (overwritten if it exists)
    :type path: string
    '''
    with open(path, mode='wb') as out_file:
        pickle.dump(obj, out_file, protocol=pickle.HIGHEST_PROTOCOL)
def load_pkl(path):
    '''Load and return the object stored in the pickle file ``path``.

    Unpickling can execute arbitrary code: only load files you trust
    (e.g. files written by save_pkl).

    :param path: path to the pickle file
    :type path: string
    :returns: the unpickled object
    '''
    with open(path, mode='rb') as in_file:
        loaded = pickle.load(in_file)
    return loaded
def readSpikes(file):
    '''Function reading spike times in the format produced by the simulation.

    Each non-blank line holds a neuron index followed by spike times, all
    whitespace separated. The same index may appear on several lines; its
    times are concatenated in file order. Blank lines are skipped.

    :param file: path to the file with spike times
    :type file: string
    :returns: list indexed by neuron id (0 .. largest id found); element i is a
        np.array with the spike times of neuron i (empty if it never appears)
    :rtype: list of np.ndarray
    '''
    trains = []
    with open(file) as in_file:
        # Iterate the file lazily instead of materializing readlines().
        for line in in_file:
            fields = line.split()
            if not fields:
                # Blank lines (e.g. a trailing empty line) used to raise IndexError.
                continue
            values = [float(v) for v in fields]
            idx = round(values[0])
            if idx >= len(trains):
                # Create empty trains for every id up to idx (ids may be skipped).
                trains.extend([] for _ in range(idx - len(trains) + 1))
            trains[idx].extend(values[1:])
    return [np.array(times) for times in trains]
class SpikeSim:
    '''Class loading and parsing the output files of a simulation.

    The main attributes are the simulation parameters and results:

    * t_start: time (ms) neglected at the beginning of the simulation; the
      stored spike times are shifted so that t_start corresponds to 0
    * t_end: end (ms) of the analysed time window (by default the simulation t_end)
    * dt: time resolution of the simulation
    * input_mode: external input mode:
        - 0 (base mode): each neuron receives an independent Poisson signal with
          mean frequency = SubNetwork::ext_in_rate
        - 2 (paper mode): the input to the striatal population is correlated
    * rho_corr_paper: ``rho_corr`` read from the input-mode configuration file
      (only when input_mode != 0, -1 otherwise)
    * data: dictionary with the spike times of each population; data['pop'] is a
      list of np.arrays, each containing the spike times of one neuron
    * subnets: sorted list of the SubNetwork names in the simulation
    * omegas: dictionary name -> ``osc_omega`` read from the optional subnets
      configuration file (empty if none was given)
    '''

    def __init__(self, path, sim_fname, neglect_t, neglect_t_end=-1, config_fname=''):
        '''Class constructor: reads the configuration, loads the spikes and crops them in time.

        :param path: path to the directory with output simulation files and configuration files
        :type path: string
        :param sim_fname: name of the simulation configuration file (inside the directory ``path``)
        :type sim_fname: string
        :param neglect_t: time (in ms) to be neglected at the beginning of the simulation
        :type neglect_t: float
        :param neglect_t_end: end time (in ms) of the analysed window; spikes after it are
            discarded. A negative value means "use the simulation t_end"
        :type neglect_t_end: float
        :param config_fname: name of the subnets_config_yaml configuration file (inside the
            directory ``path``); if given, its ``osc_omega`` entries fill ``self.omegas``
        :type config_fname: string
        :raises ValueError: if neglect_t_end exceeds the simulation t_end
        '''
        self.input_dir = path
        self.sim_filename = sim_fname
        self.t_start = neglect_t
        self.t_end = neglect_t_end
        self.dt = 0
        self.input_mode = 0
        self.rho_corr_paper = -1
        self.data = dict()
        self.subnets = []
        self.omegas = dict()
        self.getParameterValues()
        self.loadData()
        self.necglectTime()
        if config_fname != '':
            with open(self.input_dir + '/' + config_fname) as file:
                in_dict = yaml.load(file, Loader=yaml.FullLoader)
            # The subnets config is read as a list of entries with 'name' and 'osc_omega'.
            for d in in_dict:
                self.omegas[d['name']] = d['osc_omega']

    def getParameterValues(self):
        '''Method initializing the simulation parameters from the configuration files.

        :raises ValueError: if the requested t_end exceeds the simulation t_end
        '''
        with open(self.input_dir + '/' + self.sim_filename) as file:
            in_dict = yaml.load(file, Loader=yaml.FullLoader)
        dict_t_end = in_dict['t_end']
        if self.t_end < 0.:
            self.t_end = dict_t_end
        elif self.t_end > dict_t_end:
            # Raise instead of exit(): callers can handle it, and exit() is a
            # site-provided builtin that is not always available.
            raise ValueError(f'ERROR: t_end is too big: max = {dict_t_end}, passed = {self.t_end}')
        self.dt = in_dict['dt']
        self.input_mode = in_dict['input_mode']
        if self.input_mode != 0:
            # Only the file name is used: the file is looked up inside input_dir.
            paper_configfile = in_dict['input_mode_config'].split('/')[-1]
            with open(self.input_dir + '/' + paper_configfile) as file_paper:
                paper_dict = yaml.load(file_paper, Loader=yaml.FullLoader)
            self.rho_corr_paper = paper_dict['rho_corr']

    def loadData(self):
        '''Method loading spike times for each SubNetwork.

        Every ``*_spikes.txt`` file in input_dir is parsed with readSpikes. The
        population name is the last '/'- or '_'-separated token before
        ``spikes.txt`` (e.g. ``GPe_spikes.txt`` -> ``GPe``).
        '''
        for fname in glob.glob(self.input_dir + '/*_spikes.txt'):
            # NOTE(review): a name containing '_' keeps only its last part.
            pop = re.split(r'/|_', fname)[-2]
            self.data[pop] = readSpikes(fname)
            self.subnets.append(pop)
        self.subnets = sorted(self.subnets)

    def necglectTime(self):
        '''Method removing the spikes occurring before t_start and after t_end (if t_end >= 0).

        Both bounds are exclusive. The kept spike times are shifted by -t_start,
        and the arrays in self.data are replaced in place.
        '''
        for trains in self.data.values():
            for j, times in enumerate(trains):
                keep = times > self.t_start
                if self.t_end < 0.:
                    trains[j] = times[keep] - self.t_start
                else:
                    trains[j] = times[np.logical_and(keep, times < self.t_end)] - self.t_start

    # Correctly spelled alias; the original name is kept for backward compatibility.
    neglectTime = necglectTime

    def info(self):
        '''Method printing the simulation parameters'''
        print('Simulation data from: ' + self.input_dir)
        print('\t simulation config file: ' + self.sim_filename)
        print('\t subnets in the network: ', self.subnets)
        print(f'\t t_start = {self.t_start} ms')
        print(f'\t t_end = {self.t_end} ms')
        print(f'\t dt = {self.dt} ms')
        print(f'\t input_mode = {self.input_mode}')
        if self.input_mode != 0:
            print(f'\t rho_corr = {self.rho_corr_paper}')

    def saveData(self, path):
        '''Method saving the spike-time dictionary ``self.data`` as a pickle file at ``path``.'''
        save_pkl(self.data, path)

    @staticmethod
    def _grid(n_panels):
        '''Return (rows, cols) of a subplot grid for n_panels panels (2 columns unless n_panels == 1).'''
        cols = 1 if n_panels == 1 else 2
        rows = -(-n_panels // cols)  # ceiling division
        return rows, cols

    def _binned_activity(self, pop, res):
        '''Return the spike counts of population ``pop`` in consecutive bins of width ``res`` (ms).

        The bins tile [0, n_bins*res] with n_bins = int((t_end - t_start)/res), so the
        sampling rate of the returned signal is exactly 1000/res Hz. With only
        ``bins=n_bins``, np.histogram would span [first spike, last spike] and the
        bin width would differ from res.
        '''
        n_bins = int((self.t_end - self.t_start) / res)
        counts, _ = np.histogram(np.concatenate(self.data[pop]), bins=n_bins,
                                 range=(0., n_bins * res))
        return counts

    @staticmethod
    def _show_or_save(save_img, show=True):
        '''Save the current figure to save_img (dpi 500) if given, else show it when show is True; then close it.'''
        if save_img != '':
            plt.savefig(save_img, dpi=500)
        elif show:
            plt.show()
        plt.close()

    def histogram(self, pop='', res=1., save_img=''):
        '''Method showing or saving the spiking activity of a given subnet.

        :param pop: desired population; if 'all' is passed all populations are shown.
            If empty (or unknown) the population is asked interactively (enter 'stop' to quit)
        :type pop: string
        :param res: time width (ms) of each bin in the histogram
        :type res: float
        :param save_img: path and name of the file to be saved; if empty the figure is shown
        :type save_img: string
        '''
        pop_passed = pop != ''
        n_bins = int((self.t_end - self.t_start) / res)
        hist_range = (0., n_bins * res)  # bins of exactly res ms starting at t_start
        if pop.lower() == 'all':
            rows, cols = self._grid(len(self.subnets))
            plt.figure()
            for i, p in enumerate(sorted(self.subnets)):
                try:
                    plt.subplot(rows, cols, i + 1)
                    plt.ylabel(p)
                    plt.xlabel('t [ms]')
                    plt.hist(np.concatenate(self.data[p]), bins=n_bins, range=hist_range)
                except Exception as e:
                    # Best effort: report the failing population and go on with the others.
                    print(f'{p}: {e}')
            self._show_or_save(save_img)
            return
        while True:
            if pop == '':
                pop = input('histogram: enter subnetwork: ')
            if pop.lower() == 'stop':
                break
            if pop not in self.subnets:
                print(f'No subnet with name "{pop}", try again...')
                pop = ''
                continue
            plt.hist(np.concatenate(self.data[pop]), bins=n_bins, range=hist_range)
            plt.title(f'{pop} - {self.rho_corr_paper}')
            self._show_or_save(save_img)
            if pop_passed:
                break
            pop = ''

    def MeanActivity(self):
        '''Method computing the mean spiking activity of the subnets.

        :returns: dictionary subnet -> [total spikes per ms, number of neurons,
            spikes per ms per neuron (0 if the subnet has no neurons)]
        '''
        duration = self.t_end - self.t_start
        result = {}
        for s in self.subnets:
            n = len(self.data[s])
            counts = sum(len(times) for times in self.data[s])
            result[s] = [counts / duration, n, counts / n / duration if n > 0 else 0]
        return result

    def periodogram(self, pop='', res=1., N_parseg=500, save_img=''):
        '''Method computing the spectrogram of the (z-scored) binned spiking activity of a subnetwork.

        :param pop: desired population; 'all' plots every population. If empty (or unknown)
            the population is asked interactively (enter 'stop' to quit)
        :type pop: string
        :param res: bin width in ms (sampling frequency = 1000/res Hz)
        :type res: float
        :param N_parseg: length (in bins) of each segment of scipy.signal.spectrogram;
            consecutive segments overlap by half
        :type N_parseg: int
        :param save_img: path of the image to be saved; if empty the figure is shown
        :type save_img: string
        '''
        pop_passed = pop != ''
        fs = 1 / res * 1000
        noverlap = int(N_parseg / 2)
        if pop.lower() == 'all':
            rows, cols = self._grid(len(self.subnets))
            plt.figure()
            for i, p in enumerate(sorted(self.subnets)):
                try:
                    plt.subplot(rows, cols, i + 1)
                    x = stats.zscore(self._binned_activity(p, res))
                    f, t, Sxx = signal.spectrogram(x, fs, nperseg=N_parseg, noverlap=noverlap)
                    plt.pcolormesh(t, f, Sxx, shading='gouraud')
                    plt.ylim(0, 120)
                    plt.colorbar()
                    plt.ylabel(f'f [Hz] {p}')
                except Exception as e:
                    print(f'{p}: {e}')
            self._show_or_save(save_img)
            print(f'nparseg = {N_parseg}\tnoverlap={noverlap}')
            return
        while True:
            if pop == '':
                pop = input('periodogram: enter subnetwork: ')
            if pop.lower() == 'stop':
                break
            if pop not in self.subnets:
                print(f'No subnet with name "{pop}", try again...')
                pop = ''
                continue
            x = stats.zscore(self._binned_activity(pop, res))
            # nfft zero-pads each segment for a finer frequency grid; it must be >= nperseg.
            f, t, Sxx = signal.spectrogram(x, fs, nfft=max(10000, N_parseg),
                                           nperseg=N_parseg, noverlap=noverlap)
            plt.pcolormesh(t, f, Sxx, shading='gouraud')
            plt.ylim(10, 25)  # fixed zoom on the 10-25 Hz band
            plt.colorbar()
            plt.title(pop)
            plt.ylabel('f [Hz]')
            plt.xlabel('t [sec]')
            print(f'nparseg = {N_parseg}\tnoverlap={noverlap}')
            self._show_or_save(save_img)
            if pop_passed:
                break
            pop = ''

    def welch_spectogram(self, pop='', nparseg=1000, show=True, res=1., save_img='', Ns=None):
        '''Method computing the power spectral density (Welch method) of the binned spiking activity.

        The activity is the (not z-scored) spike count of the population in bins of width res.

        :param pop: desired population; 'all' processes every population. If empty (or
            unknown) the population is asked interactively (enter 'stop' to quit)
        :type pop: string
        :param nparseg: length (in bins) of each Welch segment; segments overlap by half
        :type nparseg: int
        :param show: whether to show the plotted spectra
        :type show: bool
        :param res: bin width in ms (sampling frequency = 1000/res Hz)
        :type res: float
        :param save_img: path of the image to be saved (the spectra are plotted even if
            show is False); if empty nothing is saved
        :type save_img: string
        :param Ns: optional dictionary pop -> divisor (e.g. number of neurons); if given,
            the counts of each population are divided by Ns[pop] before the PSD
        :type Ns: dict or None
        :returns: (f, {pop: psd}). For 'all' the dictionary holds every population that
            could be processed, and f is None if none could. Returns None if the
            interactive session is ended with 'stop'
        '''
        pop_passed = pop != ''
        plotting = show or save_img != ''
        fs = 1 / res * 1000
        noverlap = int(nparseg / 2)
        nfft = max(30000, nparseg)  # zero-padding for a finer frequency grid; must be >= nparseg
        if pop.lower() == 'all':
            rows, cols = self._grid(len(self.subnets))
            to_ret = dict()
            f = None  # stays None if no population could be processed
            if plotting:
                plt.figure()
            for i, p in enumerate(sorted(self.subnets)):
                try:
                    x = self._binned_activity(p, res)
                    if Ns:
                        x = x / Ns[p]
                    f, psd = signal.welch(x, fs, nperseg=nparseg, noverlap=noverlap, nfft=nfft,
                                          scaling='density', window='hamming')
                    if plotting:
                        plt.subplot(rows, cols, i + 1)
                        plt.plot(f, psd)
                        plt.xlabel('f [Hz]')
                        plt.ylabel(f'PSD {p} [u.a.]')
                        plt.xlim(0, 120)
                    to_ret[p] = psd
                except Exception as e:
                    print(f'{p}: {e}')
            if plotting:
                self._show_or_save(save_img, show)
            return (f, to_ret)
        while True:
            if pop == '':
                pop = input('welch_spectogram: enter subnetwork: ')
            if pop.lower() == 'stop':
                break
            if pop not in self.subnets:
                print(f'No subnet with name "{pop}", try again...')
                pop = ''
                continue
            x = self._binned_activity(pop, res)
            if Ns:
                x = x / Ns[pop]
            # NOTE(review): default window here vs 'hamming' in the 'all' branch (kept as before).
            f, psd = signal.welch(x, fs, nperseg=nparseg, noverlap=noverlap, nfft=nfft,
                                  scaling='density')
            if plotting:
                plt.plot(f, psd)
                plt.xlabel('f [Hz]')
                plt.ylabel(f'PSD {pop}')
                # Reference spectrum of the shuffled counts (temporal structure destroyed).
                # Separate names so the returned spectrum is the real one, not this surrogate.
                f_shuf, psd_shuf = signal.welch(np.random.permutation(x), fs, nperseg=nparseg,
                                                noverlap=noverlap, scaling='density')
                plt.plot(f_shuf, psd_shuf, label='shuffled', color='black', linewidth=0.7)
                plt.legend()
                self._show_or_save(save_img, show)
            if pop_passed:
                return f, {pop: psd}
            pop = ''

    def activityDistribution(self, pop='', save_img=''):
        '''Method plotting the distribution of the number of spikes emitted by each neuron of a subnetwork.

        :param pop: desired population; 'all' plots every population. If empty (or unknown)
            the population is asked interactively (enter 'stop' to quit)
        :type pop: string
        :param save_img: path of the image to be saved; if empty the figure is shown
        :type save_img: string
        '''
        pop_passed = pop != ''
        if pop.lower() == 'all':
            rows, cols = self._grid(len(self.subnets))
            plt.figure()
            for i, p in enumerate(sorted(self.subnets)):
                try:
                    plt.subplot(rows, cols, i + 1)
                    plt.ylabel(p)
                    n_spikes = [len(times) for times in self.data[p]]
                    # One unit-wide bin centred on each integer count. Integer edges
                    # min..max would merge the two largest counts into the last bin.
                    plt.hist(n_spikes, bins=np.arange(min(n_spikes), max(n_spikes) + 2) - 0.5)
                except Exception as e:
                    print(f'{p}: {e}')
            self._show_or_save(save_img)
            return
        while True:
            if pop == '':
                pop = input('activity distribution: enter subnetwork: ')
            if pop.lower() == 'stop':
                break
            if pop not in self.subnets:
                print(f'No subnet with name "{pop}", try again...')
                pop = ''
                continue
            plt.hist([len(times) for times in self.data[pop]])
            plt.title(f'{pop} - {self.rho_corr_paper}')
            self._show_or_save(save_img)
            if pop_passed:
                break
            pop = ''

    @staticmethod
    def crossCorr(x, y, L, rescale=True):
        '''Method computing the cross correlation between two vectors averaged over sub-vectors of length L.

        The signals are split into M = len(x)//L - 1 blocks. For each block, a window of
        length 2L (the block plus L/2 samples on each side) is z-scored (or only de-meaned).
        The central L samples of x are then correlated with y at lags -L/2 ... L/2-1.

        Notes:

        * L must be even and less than len(x)/2
        * the two vectors must be of the same length
        * if rescale is True (default) each window is z-scored before calculating the
          cross correlation; otherwise only the mean is subtracted from the data

        :returns: (lags, c) with lags = np.arange(-L/2, L/2) in samples and c the
            accumulated cross correlation divided by M*L/2
        :raises ValueError: if L is odd, the lengths differ, or len(x) < 2L (no block)
        '''
        if L % 2 == 1:
            raise ValueError('ERROR: not even L passed to cross_corr...')
        if len(x) != len(y):
            raise ValueError(f'ERROR: not same lenght arrays passed to cross_corr... {len(x)} vs {len(y)}')
        M = len(x) // L - 1
        if M < 1:
            # Previously this silently returned NaNs (division by M == 0).
            raise ValueError(f'cross_corr needs len(x) >= 2*L, got len(x)={len(x)}, L={L}')
        half = L // 2
        print(f'convolution calculated with {M} blocks of size {L}')
        c = np.zeros(L)
        for i in range(M):
            lo, hi = i * L, (i + 2) * L  # window of length 2L around block i
            if rescale:
                win_y = stats.zscore(y[lo:hi])
                win_x = stats.zscore(x[lo:hi])
            else:
                win_y = y[lo:hi] - y[lo:hi].mean()
                win_x = x[lo:hi] - x[lo:hi].mean()
            win_x = win_x[half:half + L]  # central L samples of the x window
            # c[k] += sum_n win_x[n] * win_y[n + k], k = 0..L-1, i.e. lag k - L/2
            c += np.correlate(win_y, win_x, mode='valid')[:L]
        return np.arange(-half, half, 1), c / (M * L / 2)

    def getAmp(self, x, L, res, omega):
        '''Method fitting A*cos(omega*t*res) to the de-meaned autocorrelation of x; plots and shows the fit.

        :param x: signal whose autocorrelation is fitted (passed to crossCorr with rescale=False)
        :param L: block length passed to crossCorr (must be even)
        :param res: factor converting the lag (in samples) into the cosine's time unit
        :param omega: angular frequency of the fitted cosine, presumably per unit of
            lag*res (e.g. rad/ms); TODO confirm
        :returns: (A, sigma_A), the best-fit amplitude and its standard error
        '''
        def f(t, A):
            return A * np.cos(omega * t * res)
        cc, corr = self.crossCorr(x, x, L, False)
        # Replace the global maximum (for an autocorrelation, normally the zero-lag peak)
        # with the largest value among the lags -L/2 ... -3.
        corr[np.argmax(corr)] = corr[np.argmax(corr[:int(L / 2) - 2])]
        pars, covm = curve_fit(f, cc, corr, [1])
        plt.plot(cc, corr)
        plt.plot(cc, f(cc, pars[0]))
        plt.title(self.input_dir)
        plt.show()
        return pars[0], np.sqrt(covm[0][0])