# Source code for python_utils

import numpy as np
from scipy import stats
from scipy import signal
from scipy.optimize import curve_fit
import glob
import yaml
import re
import matplotlib.pyplot as plt
import pickle


def save_pkl(obj, path):
    '''Serialize a python object to a pickle file.

    The file is opened in binary write mode, so an existing file at
    ``path`` is overwritten. The highest pickle protocol available is used.

    :param obj: python object (list, dictionary...) to be saved
    :type obj: generic
    :param path: path to the file in which the object is saved
    :type path: string
    '''
    with open(path, mode='wb') as out_file:
        pickle.dump(obj, out_file, protocol=pickle.HIGHEST_PROTOCOL)
def load_pkl(path):
    '''Load and return an object stored in a pickle file.

    NOTE: unpickling can execute arbitrary code; only load files coming
    from a trusted source.

    :param path: path to the pickle file to be loaded
    :type path: string
    :return: the unpickled object
    '''
    with open(path, mode='rb') as in_file:
        loaded = pickle.load(in_file)
    return loaded
def readSpikes(file):
    '''Read spike times in the format produced by the simulation.

    Every non-empty line is parsed as whitespace-separated numbers: the
    first one is the neuron index (rounded to the nearest integer), the
    remaining ones are spike times appended to that neuron. The same neuron
    may appear on several lines. Neurons whose index never appears (but is
    lower than the largest index found) get an empty array. Blank or
    whitespace-only lines are skipped instead of raising an IndexError.

    :param file: path to the file with spike times
    :type file: string
    :return: list indexed by neuron, each item being a np.array of spike times
    :rtype: list
    '''
    spikes = []
    with open(file) as in_file:
        # iterate lazily instead of loading the whole file with readlines()
        for line in in_file:
            values = [float(token) for token in line.split()]
            if not values:
                continue
            neuron = round(values[0])
            if len(spikes) <= neuron:
                # grow the list so that index `neuron` exists
                spikes.extend([] for _ in range(neuron - len(spikes) + 1))
            spikes[neuron].extend(values[1:])
    return [np.array(times) for times in spikes]
class SpikeSim:
    '''Class loading and parsing files given by a simulation.

    The main attributes are the simulation parameters and results:

    * end_t: end time of simulation
    * dt: time resolution of the simulation
    * input_mode: external input mode:

      - 0 (base mode): each neuron receives an independent poisson signal
        with mean frequency = SubNetwork::ext_in_rate
      - 2 (paper mode): the input to the striatal population is correlated
        (ask for details)

    * rho_corr_paper (only with input_mode 2)
    * data: dictionary with spike times corresponding to each population;
      data['pop'] is a list of np.arrays each containing the activity of a
      neuron
    * subnets: a list of the SubNetworks in the simulation
    '''

    def __init__(self, path, sim_fname, neglect_t, neglect_t_end=-1, config_fname=''):
        '''Class constructor.

        Reads the simulation parameters, loads the spike files found in
        ``path`` and trims the spike trains to the requested time window.
        When ``config_fname`` is given, the 'osc_omega' value of every entry
        of that yaml file is stored in ``self.omegas`` under the entry 'name'.

        :param path: path to the directory with output simulation files and configuration files
        :type path: string
        :param sim_fname: name of the simulation configuration file (inside the directory matched by path)
        :type sim_fname: string
        :param neglect_t: time (in ms) to be neglected at the beginning of the simulation
        :type neglect_t: float
        :param neglect_t_end: time (in ms) to be neglected at the end of the simulation
        :type neglect_t_end: float
        :param config_fname: name of the subnets_config_yaml configuration file (inside the directory matched by path)
        :type config_fname: string
        '''
        # where the simulation lives and which time window is analysed
        self.input_dir = path
        self.sim_filename = sim_fname
        self.t_start = neglect_t
        self.t_end = neglect_t_end

        # placeholders, overwritten by getParameterValues()
        self.dt = 0
        self.input_mode = 0
        self.rho_corr_paper = -1

        # containers filled by loadData() and by the optional config file
        self.data = dict()
        self.subnets = []
        self.omegas = dict()

        self.getParameterValues()
        self.loadData()
        self.necglectTime()

        if config_fname == '':
            return

        with open(self.input_dir + '/' + config_fname) as file:
            in_dict = yaml.load(file, Loader=yaml.FullLoader)
        for entry in in_dict:
            self.omegas[entry['name']] = entry['osc_omega']
[docs] def getParameterValues(self): '''Method initializing the simulation parameters''' with open(self.input_dir + '/' +self.sim_filename) as file: in_dict = yaml.load(file, Loader=yaml.FullLoader) dict_t_end = in_dict['t_end'] if self.t_end < 0.: self.t_end = dict_t_end elif self.t_end > dict_t_end: print(f'ERROR: t_end is too big: max = {dict_t_end}, passed = {self.t_end}') exit() self.dt = in_dict['dt'] self.input_mode = in_dict['input_mode'] if self.input_mode != 0: paper_configfile = in_dict['input_mode_config'] if ( len(paper_configfile.split('/')) !=1 ): paper_configfile = paper_configfile.split('/')[-1] with open(self.input_dir + '/'+paper_configfile) as file_paper: paper_dict = yaml.load(file_paper, Loader=yaml.FullLoader) self.rho_corr_paper = paper_dict['rho_corr']
[docs] def loadData(self): '''Method loading spike times for each SubNetwork''' subnets_files = glob.glob(self.input_dir + '/*_spikes.txt') for f in subnets_files: pop = re.split('/|_', f)[-2] self.data[pop] = readSpikes(f) self.subnets.append(pop) self.subnets = sorted(self.subnets)
[docs] def necglectTime(self): '''Method removing the spikes occurring before t_start and after t_end (if > 0)''' for i in self.data: for j in range( len(self.data[i]) ): if self.t_end < 0.: self.data[i][j] = self.data[i][j][self.data[i][j]>self.t_start] - self.t_start else: self.data[i][j] = self.data[i][j][ np.logical_and( self.data[i][j]>self.t_start, self.data[i][j]<self.t_end ) ] - self.t_start
[docs] def info(self): '''Method printing the simulation parameters''' print('Simulation data from: ' + self.input_dir) print('\t simulation config file: ' + self.sim_filename) print('\t subnets in the network: ', self.subnets) print(f'\t t_start = {self.t_start} ms') print(f'\t t_end = {self.t_end} ms') print(f'\t dt = {self.dt} ms') print(f'\t input_mode = {self.input_mode}') if (self.input_mode != 0): print(f'\t rho_corr = {self.rho_corr_paper}')
def saveData(self, path): save_pkl(self.data, path)
[docs] def histogram(self, pop = '', res=1., save_img=''): '''Method showing or saving the spiking activity of a given subnet :param pop: desidered population; if 'all' is passed all population are showed. :type pop: string :param res: time width of each bin in the histogram :type pop: float :param save_img: path and name of the file to be saved :type save_img: string ''' pop_passed = True if pop == '': pop_passed = False if pop.lower() == 'all': plt.figure() for i,p in enumerate(sorted(self.subnets)): print(p) l = len(self.subnets) cols = 1 if l==1 else 2 rows = round(l/cols) if (l%2==0 or l==1) else int(l/cols)+1 try: plt.subplot(rows,cols, i+1) plt.ylabel(p) plt.xlabel('t [ms]') plt.hist( np.concatenate(self.data[p]), bins=int((self.t_end - self.t_start)/res) ) except Exception as e: print(e) # plt.tight_layout() # plt.savefig(self.input_dir+'/activity.png', dpi=500) plt.show() else: while True: if pop == '': pop = input('histogram: enter subnetwork: ') if pop.lower() == 'stop': break if not pop in self.subnets: print(f'No subnet with name "{pop}", try again...') pop = '' continue plt.hist(np.concatenate(self.data[pop]), bins=int((self.t_end - self.t_start)/res)) plt.title(f'{pop} - {self.rho_corr_paper}') if save_img == '': plt.show() else: plt.savefig(save_img, dpi = 500) plt.close() if pop_passed: break else: pop = ''
[docs] def MeanActivity(self): '''Method computing the mean spiking activity of the subnets''' MAct = [] Ns = [] MActPerN = [] for s in self.subnets: if len(self.data[s]) >0: counts = len( np.concatenate(self.data[s]) ) else: counts = 0 MAct.append( counts/(self.t_end - self.t_start) ) n = len(self.data[s]) Ns.append(n) if n > 0: MActPerN.append(counts/n/(self.t_end - self.t_start)) else: MActPerN.append(0) return {self.subnets[i] : [ MAct[i], Ns[i], MActPerN[i] ] for i in range(len(self.subnets))}
[docs] def periodogram(self, pop='', res=1., N_parseg=500, save_img=''): '''Method computing the periodogram resulting from the (z-scored) spiking activity of the passed subnetwork''' pop_passed = True if pop == '': pop_passed = False if pop.lower() == 'all': plt.figure() for i,p in enumerate(sorted(self.subnets)): print(p) try: plt.subplot(4,2, i+1) x,_ = np.histogram(np.concatenate(self.data[p]), bins = int((self.t_end-self.t_start)/res)) x = stats.zscore(x) fs = 1/res*1000 f, t, Sxx = signal.spectrogram(x, fs, nperseg = N_parseg, noverlap=int(N_parseg/2)) plt.pcolormesh(t, f, Sxx, shading='gouraud') # plt.pcolormesh(t, f, Sxx, shading='auto') plt.ylim(0, 120) plt.colorbar() plt.ylabel(f'f [Hz] {p}') except Exception as e: print(e) # plt.tight_layout() # plt.savefig(self.input_dir+'/activity.png', dpi=500) plt.show() print(f'nparseg = {N_parseg}\tnoverlap={int(N_parseg/2)}') else: while True: if pop == '': pop = input('histogram: enter subnetwork: ') if pop.lower() == 'stop': break if not pop in self.subnets: print(f'No subnet with name "{pop}", try again...') pop = '' continue x,_ = np.histogram(np.concatenate(self.data[pop]), bins = int((self.t_end-self.t_start)/res)) x = stats.zscore(x) fs = 1/res*1000 f, t, Sxx = signal.spectrogram(x, fs, nfft= 10000,nperseg = N_parseg, noverlap=int(N_parseg/2)) plt.pcolormesh(t, f, Sxx, shading='gouraud') plt.ylim(10, 25) plt.colorbar() plt.title(pop) plt.ylabel(f'f [Hz]') plt.xlabel('t [sec]') print(f'nparseg = {N_parseg}\tnoverlap={int(N_parseg/2)}') if save_img == '': plt.show() else: plt.savefig(save_img, dpi = 500) plt.close() if pop_passed: break else: pop = ''
[docs] def welch_spectogram(self, pop='', nparseg=1000, show=True, res=1., save_img='', Ns={}): '''Method computing the spectrogram resulting from the spiking activity of the passed subnetwork using the Welch method''' pop_passed = True if pop == '': pop_passed = False if pop.lower() == 'all': l = len(self.subnets) cols = 1 if l==1 else 2 rows = round(l/cols) if (l%2==0 or l==1) else int(l/cols)+1 to_ret = dict() if show: plt.figure() for i,p in enumerate(sorted(self.subnets)): # print(p) try: x,_ = np.histogram(np.concatenate(self.data[p]), bins = int((self.t_end-self.t_start)/res)) # x = stats.zscore(x) print('not_z_scored') fs = 1/res*1000 if Ns != {}: print('do not zsc!') f, pow_welch_spect = signal.welch(x/Ns[p], fs, nperseg=nparseg, noverlap=int(nparseg/2),nfft=max(30000,nparseg), scaling='density', window='hamming') else: f, pow_welch_spect = signal.welch(x, fs, nperseg=nparseg, noverlap=int(nparseg/2),nfft=max(30000,nparseg), scaling='density', window='hamming') if show: plt.subplot(rows,cols, i+1) plt.plot(f, pow_welch_spect) plt.xlabel(f'f [Hz]') plt.ylabel(f'PSD {p} [u.a.]') plt.xlim(0, 120) # plt.yscale('log') to_ret[p] = pow_welch_spect except Exception as e: print(e) # plt.tight_layout() # plt.savefig(self.input_dir+'/activity.png', dpi=500) if show: plt.show() return (f, to_ret) else: while True: if pop == '': pop = input('welch_spectogram: enter subnetwork: ') if pop.lower() == 'stop': break if not pop in self.subnets: print(f'No subnet with name "{pop}", try again...') pop = '' continue x,_ = np.histogram(np.concatenate(self.data[pop]), bins = int((self.t_end-self.t_start)/res)) # x = stats.zscore(x) print('not_z_scored') fs = 1/res*1000 f, pow_welch_spect = signal.welch(x, fs, nperseg=nparseg, noverlap=int(nparseg/2), nfft=max(30000,nparseg), scaling='density') if show or save_img!='' : plt.plot(f, pow_welch_spect) plt.xlabel(f'f [Hz]') plt.ylabel(f'PSD {pop}') # plt.ylim(0, 120) # plt.yscale('log') np.random.shuffle(x) f, pow_welch_spect = 
signal.welch(x, fs, nperseg=nparseg, noverlap=int(nparseg/2), scaling='density') plt.plot(f, pow_welch_spect, label='shuffled', color='black', linewidth=0.7) plt.legend() # plt.plot([min(f),max(f)], [0.001]*2) # plt.xlim(0, 300) if save_img == '': if show: plt.show() else: plt.savefig(save_img, dpi = 500) plt.close() if pop_passed: return f, {pop: pow_welch_spect} else: pop = ''
[docs] def activityDistribution(self, pop = '', save_img=''): '''Method computing the distribution of the number of spike of each neuron in the subnetworks''' pop_passed = True if pop == '': pop_passed = False if pop.lower() == 'all': plt.figure() for i,p in enumerate(sorted(self.subnets)): print(p) try: plt.subplot(4,2, i+1) plt.ylabel(p) tmp = [ len(l) for l in self.data[p] ] plt.hist( tmp, bins=range(min(tmp), max(tmp) + 1, 1) ) except Exception as e: print(e) # plt.tight_layout() # plt.savefig(self.input_dir+'/activity.png', dpi=500) plt.show() else: while True: if pop == '': pop = input('activity distribution: enter subnetwork: ') if pop.lower() == 'stop': break if not pop in self.subnets: print(f'No subnet with name "{pop}", try again...') pop = '' continue plt.hist([ len(l) for l in self.data[pop] ]) plt.title(f'{pop} - {self.rho_corr_paper}') if save_img == '': plt.show() else: plt.savefig(save_img, dpi = 500) plt.close() if pop_passed: break else: pop = ''
[docs] @staticmethod def crossCorr(x, y, L, rescale=True): '''Method computing the cross correlation between two vectors mediated over subvectors of len L Notes: * L must be even and less than len(x)/2 * the two vector must be of the same lenght * if rescale is True (default) each subvector is zscored before calculating the cross correlation; otherways only the mean is subtracted to the data ''' print('using decorator') if L%2 == 1: print('ERROR: not even L passed to cross_corr...') exit() if len(x) != len(y): print(f'ERROR: not same lenght arrays passed to cross_corr... {len(x)} vs {len(y)}') exit() M = int(len(x)/L)-1 print(f'convolution calculated with {M} blocks of size {L}') c = np.zeros(L) for i in range(M): start = i*L+int(L/2) # start of x end = (i+1)*L+int(L/2) # end of x if rescale: tmp_y = stats.zscore(y[start-int(L/2):end+int(L/2)]) tmp_x = stats.zscore(x[start-int(L/2):end+int(L/2)]) else: tmp_y = y[start-int(L/2):end+int(L/2)] - y[start-int(L/2):end+int(L/2)].mean() tmp_x = x[start-int(L/2):end+int(L/2)] - x[start-int(L/2):end+int(L/2)].mean() tmp_x = tmp_x[int(L/2):L+int(L/2)] for k in range(L): c[k] += ( tmp_x * tmp_y[k:L+k] ).sum() # plt.plot(np.arange(-int(L/2), int(L/2), 1), np.array(c/(M*L/2))) # plt.plot(np.arange(len(tmp_x)), tmp_x) # plt.plot(np.arange(len(tmp_y)), tmp_y) # plt.show() return (np.arange(-int(L/2), int(L/2), 1), np.array(c/(M*L/2)))
def getAmp(self, x, L, res, omega): def f(t, A): return A*np.cos(omega*t*res) cc, corr = self.crossCorr(x,x,L, False) corr[np.argmax(corr)] = corr[np.argmax(corr[:int(L/2)-2])] pars,covm=curve_fit(f,cc,corr,[1]) plt.plot(cc, corr) plt.plot(cc, f(cc, pars[0])) plt.title(self.input_dir) plt.show() return pars[0], np.sqrt(covm[0][0])